diff --git a/src/sql_jsc/jsc.rs b/src/sql_jsc/jsc.rs index a5c5ed99526..c719f02cb1d 100644 --- a/src/sql_jsc/jsc.rs +++ b/src/sql_jsc/jsc.rs @@ -1,12 +1,12 @@ //! `bun_jsc` re-export façade for the SQL bindings. //! //! All core handle types (`JSValue`, `JSGlobalObject`, `CallFrame`, `JsError`, -//! `JsResult`, `JSObject`, `JSCell`, `JSType`, [`VirtualMachine`], +//! `JsResult`, `JSObject`, `JSType`, [`VirtualMachine`], //! [`EventLoop`], [`KeepAlive`], …) are **re-exported from `bun_jsc` / //! `bun_io`** so the `#[bun_jsc::JsClass]` / `#[bun_jsc::host_fn]` proc-macros //! see identical types. SQL-specific helpers that `bun_jsc` doesn't expose at -//! this tier are provided as extension traits ([`JSGlobalObjectSqlExt`], -//! [`VirtualMachineSqlExt`], [`EventLoopSqlExt`]). +//! this tier are provided as extension traits ([`VirtualMachineSqlExt`], +//! [`EventLoopSqlExt`]). //! //! [`RareData`] here is the **per-VM SQL state** (`mysql_context` / //! `postgresql_context`) that `bun_runtime::jsc_hooks::RuntimeState` owns by @@ -27,10 +27,9 @@ use core::ptr::NonNull; // ────────────────────────────────────────────────────────────────────────── pub use bun_jsc::{ - ArrayBuffer, CallFrame, CoerceTo, ErrorBuilder, ErrorCode, ExternColumnIdentifier, - ExternColumnIdentifierValue, GlobalRef, JSArrayIterator, JSCell, JSGlobalObject, JSObject, - JSType, JSValue, JsCell, JsError, JsRef, JsResult, MarkedArgumentBuffer, StringJsc, - StrongOptional, ThrowFmtArgs, ZigStringJsc, bun_string_jsc, host_fn, + ArrayBuffer, CallFrame, ErrorBuilder, ErrorCode, ExternColumnIdentifier, GlobalRef, + JSArrayIterator, JSGlobalObject, JSObject, JSType, JSValue, JsCell, JsError, JsRef, JsResult, + MarkedArgumentBuffer, StringJsc, StrongOptional, bun_string_jsc, host_fn, }; /// Re-export — `bun_jsc` now defines `IntegerRange` at its crate root and the @@ -146,49 +145,6 @@ pub(crate) fn create_bun_socket_error_to_js( } } -// ────────────────────────────────────────────────────────────────────────── -// JSGlobalObject — SQL-specific extension surface. -// ────────────────────────────────────────────────────────────────────────── - -/// SQL-side helpers on `JSGlobalObject` not provided by `bun_jsc` (or where -/// the SQL bindings need a slightly different signature). -pub(crate) trait JSGlobalObjectSqlExt { - fn err_out_of_range<'a>(&'a self, args: core::fmt::Arguments<'a>) -> ErrorBuilder<'a>; - fn throw_invalid_arguments_fmt(&self, args: core::fmt::Arguments<'_>) -> JsResult; - /// `globalObject.bunVM()` — `bun_jsc::JSGlobalObject::bun_vm()` returns - /// `&mut VirtualMachine`; this `&`-receiver form is for SQL callsites that - /// only need shared access. - fn sql_vm(&self) -> &VirtualMachine; - fn sql_vm_ptr(&self) -> *mut VirtualMachine; - - // PORT NOTE: `validate_integer_range` / `validate_big_int_range` / - // `gregorian_date_time_to_ms` were duplicated here while gated in - // `bun_jsc`; all three are now inherent on `bun_jsc::JSGlobalObject`, so - // the trait copies are removed (inherent methods always win in - // resolution, so the trait versions were dead code anyway). -} - -impl JSGlobalObjectSqlExt for JSGlobalObject { - #[inline] - fn err_out_of_range<'a>(&'a self, args: core::fmt::Arguments<'a>) -> ErrorBuilder<'a> { - self.err(ErrorCode::OUT_OF_RANGE, args) - } - #[inline] - fn throw_invalid_arguments_fmt(&self, args: core::fmt::Arguments<'_>) -> JsResult { - Err(self.throw(args)) - } - #[inline] - fn sql_vm(&self) -> &VirtualMachine { - // `JSGlobalObject::bun_vm` is the canonical safe accessor (single - // audited deref in bun_jsc); the VM is a process-lifetime singleton. - self.bun_vm() - } - #[inline] - fn sql_vm_ptr(&self) -> *mut VirtualMachine { - JSC__JSGlobalObject__bunVM(self).cast::() - } -} - // ────────────────────────────────────────────────────────────────────────── // VirtualMachine / EventLoop — direct re-exports from bun_jsc. // @@ -678,24 +634,13 @@ pub(crate) struct JSFunction { /// (`extern "sysv64"` on win-x64, `extern "C"` elsewhere). Re-exported from /// `bun_jsc` so the cfg-split lives in one place. pub use bun_jsc::host_fn::JsHostFn as JSHostFn; -pub type JSHostFnZig = fn(&JSGlobalObject, &CallFrame) -> JsResult; pub(crate) trait IntoJSHostFn: Sized { fn into_js_host_fn(self) -> JSHostFn; } #[doc(hidden)] -pub(crate) struct HostFnRaw; -#[doc(hidden)] pub(crate) struct HostFnResult; -#[doc(hidden)] -pub(crate) struct HostFnPlain; -impl IntoJSHostFn for JSHostFn { - #[inline] - fn into_js_host_fn(self) -> JSHostFn { - self - } -} // `jsc_host_abi!` can't express a generic `where` clause, so cfg-split the // thunk body manually (sysv64 on win-x64, C elsewhere — matches `JSHostFn`). // The where-clause is bracketed to avoid `tt`-muncher ambiguity against `{`. @@ -742,32 +687,6 @@ where thunk:: } } -impl IntoJSHostFn for F -where - F: Fn(&JSGlobalObject, &CallFrame) -> JSValue + Copy + 'static, -{ - fn into_js_host_fn(self) -> JSHostFn { - debug_assert_eq!( - core::mem::size_of::(), - 0, - "IntoJSHostFn: expected fn item (ZST)" - ); - let _ = self; - sql_jsc_host_thunk! { - thunk(g: *mut JSGlobalObject, c: *mut CallFrame) -> JSValue - where [F: Fn(&JSGlobalObject, &CallFrame) -> JSValue + Copy + 'static] - { - let f: F = bun_core::ffi::conjure_zst::(); - // JSC passes live non-null pointers; both outlive the host-fn - // call (the `ParentRef` invariant). Safe `Deref` recovers `&T`. - let global = bun_ptr::ParentRef::from(NonNull::new(g).expect("JSC host fn: global non-null")); - let frame = bun_ptr::ParentRef::from(NonNull::new(c).expect("JSC host fn: callframe non-null")); - f(&global, &frame) - } - } - thunk:: - } -} #[repr(u8)] #[derive(Clone, Copy, Default)] @@ -828,8 +747,7 @@ macro_rules! put_host_functions { } impl JSFunction { - /// Accepts either a raw [`JSHostFn`] (C-ABI) or a safe Rust - /// `fn(&JSGlobalObject, &CallFrame) -> JSValue` / `-> JsResult` + /// Accepts a safe Rust `fn(&JSGlobalObject, &CallFrame) -> JsResult` /// via [`IntoJSHostFn`] (Zig: `jsc.toJSHostFn(fn)`). pub(crate) fn create>( global: &JSGlobalObject, @@ -852,40 +770,6 @@ impl JSFunction { } } -// ────────────────────────────────────────────────────────────────────────── -// CallFrame helpers — `bun_jsc::ArgumentsSlice` exists; this local variant -// keeps the `&VirtualMachine` (local view) signature the SQL callsites use. -// ────────────────────────────────────────────────────────────────────────── - -pub mod call_frame { - use super::*; - /// `Node.ArgumentsSlice` — cursor over a `&[JSValue]` (CallFrame.zig:289). - pub(crate) struct ArgumentsSlice<'a> { - remaining: &'a [JSValue], - _vm: *const c_void, - } - impl<'a> ArgumentsSlice<'a> { - /// Generic over the VM handle so it accepts both the local - /// [`VirtualMachine`] and `bun_jsc`'s (callers pass `global.bun_vm()`, - /// which returns a raw `*mut VirtualMachineRef`). The VM is not - /// dereferenced — it's only carried for API parity with the Zig - /// `Node.ArgumentsSlice` shape — so it's accepted by-value and dropped. - pub(crate) fn init(_vm: V, slice: &'a [JSValue]) -> Self { - Self { - remaining: slice, - _vm: core::ptr::null(), - } - } - /// Zig `nextEat` (CallFrame.zig) — return the head **and** advance. - #[inline] - pub(crate) fn next_eat(&mut self) -> Option { - let (first, rest) = self.remaining.split_first()?; - self.remaining = rest; - Some(*first) - } - } -} - // ────────────────────────────────────────────────────────────────────────── // MarkedArgumentBuffer::run — C++-side trampoline. `bun_jsc::MarkedArgumentBuffer` // exposes `new(f)`; the SQL callsites use the lower-level `run(ctx, fn_ptr)` @@ -910,21 +794,56 @@ impl SslCtxCache { } } -// ────────────────────────────────────────────────────────────────────────── -// extern "C" — **C++** JSC bindings (src/jsc/bindings/bindings.cpp) used by -// the extension traits above. No Rust-defined symbols are declared here; all -// `bun_runtime` cross-calls go through [`SqlRuntimeHooks`] so the compiler -// type-checks both sides at the registration site. -// ────────────────────────────────────────────────────────────────────────── -unsafe extern "C" { - // JSValue — by-value `JSValue` (encoded NaN-boxed u64) + scalar args; the - // C++ side reads no caller memory and upholds no invariants the caller must - // discharge, so these are `safe fn`. - - // JSGlobalObject — `&JSGlobalObject` is ABI-identical to a non-null - // `*const JSGlobalObject`; the reference type discharges the validity - // precondition, so `safe fn`. Returned pointer is opaque (caller derefs - // under its own SAFETY obligation). - safe fn JSC__JSGlobalObject__bunVM(this: &JSGlobalObject) -> *mut c_void; - +/// Shared tail of Postgres `setupTLS` / MySQL `upgradeToTLS`: adopt a +/// connected plain-TCP socket into `tls_group` as `kind`, attach an `SSL*` +/// from `ssl_ctx` (SNI from `tls_config.server_name()`), point the new +/// socket's ext slot at `ext_ptr`, hand the resulting TLS socket to +/// `install_socket`, then kick the TLS handshake. Returns `false` if adoption +/// failed (the caller maps that to its own error path). +/// +/// The ext slot is `Option>` — the Rust layout-equivalent of Zig's +/// 8-byte null-niche `?*T`. Using `Option<*mut T>` would request 16 bytes +/// (separate discriminant) and desync with the trampoline reader +/// (`uws_handlers.rs`), which reads the slot as `Option>`. +/// +/// # Safety +/// - `raw` must be a live, connected `us_socket_t*`; adoption invalidates it +/// (the C side may realloc and return a different pointer). +/// - `tls_config.server_name()` must be null or a NUL-terminated C string that +/// outlives the handshake. +pub(crate) unsafe fn adopt_socket_tls( + raw: *mut bun_uws::us_socket_t, + tls_group: &mut bun_uws::SocketGroup, + kind: bun_uws::SocketKind, + ssl_ctx: &mut bun_uws::SslCtx, + tls_config: &api::server_config::SSLConfig, + ext_ptr: *mut T, + install_socket: impl FnOnce(bun_uws::AnySocket), +) -> bool { + let server_name = tls_config.server_name(); + // SAFETY: caller contract — `server_name` is null or NUL-terminated. + let sni = (!server_name.is_null()).then(|| unsafe { bun_core::ffi::cstr(server_name) }); + let ext_size = core::mem::size_of::>>() as i32; + + // SAFETY: `raw` is a live connected `us_socket_t*` (caller contract); + // adopt_tls may realloc and return a different ptr. + let Some(new_socket) = + (unsafe { &mut *raw }).adopt_tls(tls_group, kind, ssl_ctx, sni, ext_size, ext_size) + else { + return false; + }; + let new_socket = new_socket.as_ptr(); + // SAFETY: `new_socket` is a live us_socket_t freshly returned by + // `adopt_tls`; its ext slot is sized for `Option>` above. One + // `&mut` reborrow drives both safe inherent methods (`ext` / + // `start_tls_handshake`). + let sock = unsafe { &mut *new_socket }; + *sock.ext::>>() = NonNull::new(ext_ptr); + install_socket(bun_uws::AnySocket::SocketTls(bun_uws::SocketTLS { + socket: bun_uws::InternalSocket::Connected(new_socket), + })); + // ext is repointed and the owner's socket field swapped; safe to kick the + // handshake (any dispatch lands in the new owner). + sock.start_tls_handshake(); + true } diff --git a/src/sql_jsc/lib.rs b/src/sql_jsc/lib.rs index 973229ea957..99b8bd61157 100644 --- a/src/sql_jsc/lib.rs +++ b/src/sql_jsc/lib.rs @@ -20,11 +20,15 @@ pub mod shared { #[path = "CachedStructure.rs"] pub mod cached_structure; + pub mod connection_args; + pub mod datetime_text; #[path = "ObjectIterator.rs"] pub mod object_iterator; + pub mod query_args; + #[path = "QueryBindingIterator.rs"] pub mod query_binding_iterator; diff --git a/src/sql_jsc/mysql/JSMySQLConnection.rs b/src/sql_jsc/mysql/JSMySQLConnection.rs index 6fb2f49287f..c831f166cbf 100644 --- a/src/sql_jsc/mysql/JSMySQLConnection.rs +++ b/src/sql_jsc/mysql/JSMySQLConnection.rs @@ -4,12 +4,11 @@ use core::ffi::c_void; use crate::jsc::{ CallFrame, EventLoopSqlExt as _, EventLoopTimer, EventLoopTimerState, EventLoopTimerTag, GlobalRef, HasAutoFlush, JSGlobalObject, JSValue, JsCell, JsRef, JsResult, KeepAlive, - VirtualMachine, VirtualMachineSqlExt as _, api::server_config::SSLConfig, - codegen::js_mysql_connection as js, webcore::AutoFlusher, + VirtualMachine, VirtualMachineSqlExt as _, codegen::js_mysql_connection as js, + webcore::AutoFlusher, }; use crate::shared::CachedStructure; -use bun_boringssl_sys as boringssl; -use bun_core::strings; +use crate::shared::connection_args; use bun_core::{TimespecMockMode, timespec}; use bun_ptr::{AsCtxPtr, BackRef, ParentRef}; use bun_sql::mysql::MySQLQueryResult; @@ -472,120 +471,37 @@ impl JSMySQLConnection { // SAFETY: JS-thread only; short-lived `&mut` to the singleton VM via raw ptr, // no other live borrow in this scope. let vm = global_object.bun_vm().as_mut(); - let arguments = callframe.arguments(); - let hostname_str = bun_core::OwnedString::new(arguments[0].to_bun_string(global_object)?); - let port = arguments[1].coerce::(global_object)?; - - let username_str = bun_core::OwnedString::new(arguments[2].to_bun_string(global_object)?); - let password_str = bun_core::OwnedString::new(arguments[3].to_bun_string(global_object)?); - let database_str = bun_core::OwnedString::new(arguments[4].to_bun_string(global_object)?); - // TODO: update this to match MySQL. - let ssl_mode: SSLMode = match arguments[5].to_int32() { - 0 => SSLMode::Disable, - 1 => SSLMode::Prefer, - 2 => SSLMode::Require, - 3 => SSLMode::VerifyCa, - 4 => SSLMode::VerifyFull, - _ => SSLMode::Disable, + // Args 0..=14 (hostname/port/credentials/sslMode/tls/options/path/ + // callbacks/timeouts) are decoded by the shared helper, which also + // builds the TLS `SSL_CTX` and returns it inside the `args.tls` + // errdefer guard. Ownership passes to `MySQLConnection.init` once + // `Box::new` succeeds — `into_inner` disarms the guard at that point so + // the connect-fail path (which `deref()`s the connection) doesn't + // double-free. + let Some(args) = connection_args::parse::(vm, global_object, callframe)? else { + return Ok(JSValue::ZERO); }; - - let tls_object = arguments[6]; - - let mut tls_config: SSLConfig = SSLConfig::default(); - let mut secure: Option<*mut uws::SslCtx> = None; - if ssl_mode != SSLMode::Disable { - tls_config = if tls_object.is_boolean() && tls_object.to_boolean() { - SSLConfig::default() - } else if tls_object.is_object() { - match SSLConfig::from_js(&mut *vm, global_object, tls_object) { - Ok(Some(c)) => c, - Ok(None) => SSLConfig::default(), - Err(_) => return Ok(JSValue::ZERO), - } - } else { - return Err(global_object - .throw_invalid_arguments(format_args!("tls must be a boolean or an object"))); - }; - - if global_object.has_exception() { - drop(tls_config); - return Ok(JSValue::ZERO); - } - - // We always request the cert so we can verify it and also we manually - // abort the connection if the hostname doesn't match. Built here so - // CA/cert errors throw synchronously, applied later by upgradeToTLS. - // Goes through the per-VM weak `SSLContextCache` so every pooled - // connection / reconnect shares one `SSL_CTX*` per distinct config. - let mut err = uws::create_bun_socket_error_t::none; - secure = vm - .ssl_ctx_cache() - .get_or_create_opts(&tls_config.as_usockets_for_client_verification(), &mut err); - if secure.is_none() { - drop(tls_config); - return Err( - global_object.throw_value(crate::jsc::create_bun_socket_error_to_js( - err, - global_object, - )), - ); - } - } - // Covers `try arguments[7/8].toBunString()` and the null-byte rejection - // below. Ownership passes to `MySQLConnection.init` once `Box::new` - // succeeds — we null the locals at that point so the connect-fail path - // (which `deref()`s the connection) doesn't double-free. - let tls_guard = scopeguard::guard((secure, tls_config), |(s, cfg)| { - if let Some(s) = s { - // SAFETY: secure was created by ssl_ctx_cache; we own one ref until transferred. - unsafe { boringssl::SSL_CTX_free(s) }; - } - drop(cfg); - }); - - let options_str = bun_core::OwnedString::new(arguments[7].to_bun_string(global_object)?); - let path_str = bun_core::OwnedString::new(arguments[8].to_bun_string(global_object)?); + // MySQL-only argument; `args.use_unnamed_prepared_statements` is + // intentionally unused (MySQL doesn't support unnamed prepared + // statements). + let allow_public_key_retrieval = callframe.argument(15).to_boolean(); // PORT NOTE: Zig packed all five strings into one `StringBuilder`-owned // arena and handed `[]const u8` slices into it to `MySQLConnection.init`. // The Rust `init` takes `Box<[u8]>` per field (each separately owned), - // so we just copy each string into its own allocation. `options_buf` + // so each string moves (or copies) into its own allocation. `options_buf` // (the original arena handle, kept only so `cleanup()` could free it) // becomes an empty box. - let username: Box<[u8]> = Box::from(username_str.to_utf8_without_ref().slice()); - let password: Box<[u8]> = Box::from(password_str.to_utf8_without_ref().slice()); - let database: Box<[u8]> = Box::from(database_str.to_utf8_without_ref().slice()); - let options: Box<[u8]> = Box::from(options_str.to_utf8_without_ref().slice()); - let path: Box<[u8]> = Box::from(path_str.to_utf8_without_ref().slice()); + let username: Box<[u8]> = args.username.into_boxed_bytes(); + let password: Box<[u8]> = args.password.into_boxed_bytes(); + let database: Box<[u8]> = args.database.into_boxed_bytes(); + let options: Box<[u8]> = args.options.into_boxed_bytes(); + let path: Box<[u8]> = args.path.into_boxed_bytes(); let options_buf: Box<[u8]> = Box::default(); - // Reject null bytes in connection parameters to prevent protocol injection - // (null bytes act as field terminators in the MySQL wire protocol). - for (slice, msg) in [ - (&username[..], "username must not contain null bytes"), - (&password[..], "password must not contain null bytes"), - (&database[..], "database must not contain null bytes"), - (&path[..], "path must not contain null bytes"), - ] { - if !slice.is_empty() && strings::index_of_char(slice, 0).is_some() { - // tls_config / secure released by the guard above. - return Err(global_object.throw_invalid_arguments(format_args!("{msg}"))); - } - } - - let on_connect = arguments[9]; - let on_close = arguments[10]; - let idle_timeout = arguments[11].to_int32(); - let connection_timeout = arguments[12].to_int32(); - let max_lifetime = arguments[13].to_int32(); - let use_unnamed_prepared_statements = arguments[14].as_boolean(); - // MySQL doesn't support unnamed prepared statements - let _ = use_unnamed_prepared_statements; - let allow_public_key_retrieval = callframe.argument(15).to_boolean(); - // Ownership transferred into `ptr.connection`; disarm the errdefer so the // connect-fail `ptr.deref()` is the sole cleanup path from here on. - let (secure, tls_config) = scopeguard::ScopeGuard::into_inner(tls_guard); + let (secure, tls_config) = scopeguard::ScopeGuard::into_inner(args.tls); let ptr: *mut JSMySQLConnection = bun_core::heap::into_raw(Box::new(JSMySQLConnection { ref_count: Cell::new(1), @@ -601,13 +517,13 @@ impl JSMySQLConnection { options_buf, tls_config, secure, - ssl_mode, + args.ssl_mode, allow_public_key_retrieval, )), auto_flusher: JsCell::new(AutoFlusher::default()), - idle_timeout_interval_ms: u32::try_from(idle_timeout).expect("int cast"), - connection_timeout_ms: u32::try_from(connection_timeout).expect("int cast"), - max_lifetime_interval_ms: u32::try_from(max_lifetime).expect("int cast"), + idle_timeout_interval_ms: u32::try_from(args.idle_timeout).expect("int cast"), + connection_timeout_ms: u32::try_from(args.connection_timeout).expect("int cast"), + max_lifetime_interval_ms: u32::try_from(args.max_lifetime).expect("int cast"), timer: JsCell::new(EventLoopTimer::init_paused( EventLoopTimerTag::MySQLConnectionTimeout, )), @@ -622,7 +538,7 @@ impl JSMySQLConnection { let this = ParentRef::from(core::ptr::NonNull::new(ptr).expect("heap::into_raw non-null")); { - let hostname = hostname_str.to_utf8(); + let hostname = args.hostname.to_utf8(); // MySQL always opens plain TCP first; STARTTLS adopts into the TLS // group after the SSLRequest exchange. @@ -642,7 +558,7 @@ impl JSMySQLConnection { uws::DispatchKind::Mysql, None, hostname.slice(), - port, + args.port, ptr, false, ) @@ -668,8 +584,8 @@ impl JSMySQLConnection { js_value.ensure_still_alive(); this.js_value .with_mut(|r| r.set_strong(js_value, global_object)); - js::onconnect_set_cached(js_value, global_object, on_connect); - js::onclose_set_cached(js_value, global_object, on_close); + js::onconnect_set_cached(js_value, global_object, args.on_connect); + js::onclose_set_cached(js_value, global_object, args.on_close); Ok(js_value) } @@ -713,34 +629,28 @@ impl JSMySQLConnection { Ok(JSValue::UNDEFINED) } - fn consume_on_connect_callback(&self, global_object: &JSGlobalObject) -> Option { + /// `js_value` if the VM is not shutting down and the JS wrapper ref + /// still resolves; `None` otherwise. + #[inline] + fn live_js_value(&self) -> Option { if self.vm().is_shutting_down() { return None; } - if let Some(value) = self.js_value.get().try_get() { - return js::onconnect_take_cached(value, global_object); - } - None + self.js_value.get().try_get() + } + + fn consume_on_connect_callback(&self, global_object: &JSGlobalObject) -> Option { + js::onconnect_take_cached(self.live_js_value()?, global_object) } fn consume_on_close_callback(&self, global_object: &JSGlobalObject) -> Option { - if self.vm().is_shutting_down() { - return None; - } - if let Some(value) = self.js_value.get().try_get() { - return js::onclose_take_cached(value, global_object); - } - None + js::onclose_take_cached(self.live_js_value()?, global_object) } pub fn get_queries_array(&self) -> JSValue { - if self.vm().is_shutting_down() { - return JSValue::UNDEFINED; - } - if let Some(value) = self.js_value.get().try_get() { - return js::queries_get_cached(value).unwrap_or(JSValue::UNDEFINED); - } - JSValue::UNDEFINED + self.live_js_value() + .and_then(js::queries_get_cached) + .unwrap_or(JSValue::UNDEFINED) } #[inline] diff --git a/src/sql_jsc/mysql/JSMySQLQuery.rs b/src/sql_jsc/mysql/JSMySQLQuery.rs index b227b92d769..9b1b33caf82 100644 --- a/src/sql_jsc/mysql/JSMySQLQuery.rs +++ b/src/sql_jsc/mysql/JSMySQLQuery.rs @@ -3,8 +3,8 @@ use core::ptr::NonNull; use crate::jsc::codegen::{js_mysql_connection, js_mysql_query as js}; use crate::jsc::{ - self as jsc, CallFrame, JSGlobalObject, JSGlobalObjectSqlExt as _, JSValue, JsRef, JsResult, - VirtualMachine, VirtualMachineSqlExt as _, + self as jsc, CallFrame, JSGlobalObject, JSValue, JsRef, JsResult, VirtualMachine, + VirtualMachineSqlExt as _, }; use bun_jsc::JsCell; use bun_ptr::{AsCtxPtr, BackRef, ParentRef}; @@ -16,6 +16,7 @@ use bun_sql::shared::sql_query_result_mode::SQLQueryResultMode; use super::js_mysql_connection::MySQLConnection; use crate::mysql::protocol::any_mysql_error_jsc::mysql_error_to_js; use crate::postgres::command_tag_jsc::CommandTagJsc as _; +use crate::shared::query_args; // PORT NOTE: `my_sql_query` exports both the `MySQLQuery` *struct* and a // `declare_scope!`-generated `MySQLQuery` *static* (ScopedLogger). Importing // the name once pulls in both namespaces, so the `debug!` macro below resolves @@ -96,56 +97,19 @@ impl JSMySQLQuery { global_this: &JSGlobalObject, callframe: &CallFrame, ) -> JsResult { - let arguments = callframe.arguments(); - let mut args = jsc::call_frame::ArgumentsSlice::init(global_this.sql_vm(), arguments); - // defer args.deinit() — handled by Drop - let Some(query) = args.next_eat() else { - return Err(global_this.throw(format_args!("query must be a string"))); - }; - let Some(values) = args.next_eat() else { - return Err(global_this.throw(format_args!("values must be an array"))); - }; - - if !query.is_string() { - return Err(global_this.throw(format_args!("query must be a string"))); - } - - if values.js_type() != jsc::JSType::Array { - return Err(global_this.throw(format_args!("values must be an array"))); - } - - let pending_value: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED); - let columns: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED); - let js_bigint: JSValue = args.next_eat().unwrap_or(JSValue::FALSE); - let js_simple: JSValue = args.next_eat().unwrap_or(JSValue::FALSE); - - let bigint = js_bigint.is_boolean() && js_bigint.as_boolean(); - let simple = js_simple.is_boolean() && js_simple.as_boolean(); - if simple { - if values.get_length(global_this)? > 0 { - return Err(global_this - .throw_invalid_arguments(format_args!("simple query cannot have parameters"))); - } - if query.get_length(global_this)? >= i32::MAX as u64 { - return Err(global_this.throw_invalid_arguments(format_args!("query is too long"))); - } - } - if !pending_value.js_type().is_array_like() { - return Err(global_this.throw_invalid_argument_type("query", "pendingValue", "Array")); - } + let args = query_args::parse(global_this, callframe)?; let this_ptr = bun_core::heap::into_raw(Box::new(Self { this_value: JsCell::new(JsRef::empty()), ref_count: Cell::new(1), - // Stored with full write provenance for later `&mut *p` at use sites. - vm: BackRef::from( - NonNull::new(global_this.sql_vm_ptr()).expect("sql_vm_ptr() is non-null"), - ), + // JS-thread VM singleton with full write provenance (provenance + // comes from the thread-local `*mut` inside `as_mut`). + vm: BackRef::new_mut(global_this.bun_vm().as_mut()), global_object: BackRef::new(global_this), query: JsCell::new(MySQLQuery::init( - query.to_bun_string(global_this)?, - bigint, - simple, + args.query.to_bun_string(global_this)?, + args.bigint, + args.simple, )), })); // `heap::into_raw` is `Box::into_raw` — never null. Uniquely owned here @@ -157,10 +121,10 @@ impl JSMySQLQuery { this_value.ensure_still_alive(); this.this_value.with_mut(|v| v.set_weak(this_value)); - this.set_binding(values); - this.set_pending_value(pending_value); - if !columns.is_undefined() { - this.set_columns(columns); + this.set_binding(args.values); + this.set_pending_value(args.pending_value); + if !args.columns.is_undefined() { + this.set_columns(args.columns); } Ok(this_value) @@ -198,7 +162,7 @@ impl JSMySQLQuery { if !global_object.has_exception() { return Err(global_object.throw_value(mysql_error_to_js( global_object, - "failed to execute query", + b"failed to execute query", err, ))); } @@ -358,7 +322,7 @@ impl JSMySQLQuery { if let Some(err_) = self.global_object().try_take_exception() { self.reject_with_js_value(queries_array, err_); } else { - let instance = mysql_error_to_js(self.global_object(), "Failed to bind query", err); + let instance = mysql_error_to_js(self.global_object(), b"Failed to bind query", err); instance.ensure_still_alive(); self.reject_with_js_value(queries_array, instance); } @@ -391,7 +355,7 @@ impl JSMySQLQuery { if js_error.is_empty() { js_error = mysql_error_to_js( self.global_object(), - "Query failed", + b"Query failed", AnyMySQLError::Error::UnknownError, ); } @@ -471,7 +435,7 @@ impl JSMySQLQuery { // Rust we throw for side-effect and map to the enum variant. let _ = global_object.throw_value(mysql_error_to_js( global_object, - "failed to execute query", + b"failed to execute query", err, )); } @@ -523,83 +487,57 @@ impl JSMySQLQuery { self.query.with_mut(|q| q.mark_as_prepared()); } + /// `this_value` if the VM is not shutting down and the JS wrapper ref + /// still resolves; `None` otherwise. #[inline] - pub fn set_pending_value(&self, result: JSValue) { + fn live_this_value(&self) -> Option { if self.vm().is_shutting_down() { - return; + return None; } - if let Some(value) = self.this_value.get().try_get() { + self.this_value.get().try_get() + } + + #[inline] + pub fn set_pending_value(&self, result: JSValue) { + if let Some(value) = self.live_this_value() { js::pending_value_set_cached(value, self.global_object(), result); } } #[inline] pub fn get_pending_value(&self) -> Option { - if self.vm().is_shutting_down() { - return None; - } - if let Some(value) = self.this_value.get().try_get() { - return js::pending_value_get_cached(value); - } - None + js::pending_value_get_cached(self.live_this_value()?) } #[inline] fn set_target(&self, result: JSValue) { - if self.vm().is_shutting_down() { - return; - } - if let Some(value) = self.this_value.get().try_get() { + if let Some(value) = self.live_this_value() { js::target_set_cached(value, self.global_object(), result); } } #[inline] fn get_target(&self) -> Option { - if self.vm().is_shutting_down() { - return None; - } - if let Some(value) = self.this_value.get().try_get() { - return js::target_get_cached(value); - } - None + js::target_get_cached(self.live_this_value()?) } #[inline] fn set_columns(&self, result: JSValue) { - if self.vm().is_shutting_down() { - return; - } - if let Some(value) = self.this_value.get().try_get() { + if let Some(value) = self.live_this_value() { js::columns_set_cached(value, self.global_object(), result); } } #[inline] fn get_columns(&self) -> Option { - if self.vm().is_shutting_down() { - return None; - } - if let Some(value) = self.this_value.get().try_get() { - return js::columns_get_cached(value); - } - None + js::columns_get_cached(self.live_this_value()?) } #[inline] fn set_binding(&self, result: JSValue) { - if self.vm().is_shutting_down() { - return; - } - if let Some(value) = self.this_value.get().try_get() { + if let Some(value) = self.live_this_value() { js::binding_set_cached(value, self.global_object(), result); } } #[inline] fn get_binding(&self) -> Option { - if self.vm().is_shutting_down() { - return None; - } - if let Some(value) = self.this_value.get().try_get() { - return js::binding_get_cached(value); - } - None + js::binding_get_cached(self.live_this_value()?) } // Helpers for stored back-references. diff --git a/src/sql_jsc/mysql/MySQLConnection.rs b/src/sql_jsc/mysql/MySQLConnection.rs index 2f3a1b94f24..fd70d805a14 100644 --- a/src/sql_jsc/mysql/MySQLConnection.rs +++ b/src/sql_jsc/mysql/MySQLConnection.rs @@ -38,6 +38,7 @@ use crate::mysql::js_mysql_connection::JSMySQLConnection; use crate::mysql::js_mysql_query::JSMySQLQuery; use crate::mysql::my_sql_request_queue::MySQLRequestQueue; use crate::mysql::my_sql_statement::{self as mysql_statement, MySQLStatement, Param}; +use crate::shared::connection_args; pub use bun_sql::mysql::protocol::error_packet::ErrorPacket; // Zig: `pub const Status = ConnectionState;` — re-export so callers can write @@ -322,10 +323,11 @@ impl MySQLConnection { pub fn upgrade_to_tls(&mut self) -> Result<(), FlushQueueError> { // Only adopt if we're currently a plain TCP socket. - let Socket::SocketTcp(tcp) = &self.socket else { - return Ok(()); + let raw = match &self.socket { + Socket::SocketTcp(tcp) => tcp.socket.get(), + _ => None, }; - let uws::InternalSocket::Connected(raw) = tcp.socket else { + let Some(raw) = raw else { return Ok(()); }; @@ -335,56 +337,28 @@ impl MySQLConnection { .as_mut() .mysql_socket_group::(); - // SAFETY: `secure` is set to a live `SSL_CTX*` before TLS upgrade is - // requested (Zig: `this.#secure.?`). - let ssl_ctx = unsafe { - &mut *self - .secure - .expect("secure SSL_CTX must be set before upgradeToTLS") - }; - let server_name = self.tls_config.server_name(); - let sni = if server_name.is_null() { - None - } else { - // SAFETY: `server_name` is a NUL-terminated C string owned by - // `tls_config` for the connection lifetime. - Some(unsafe { bun_core::ffi::cstr(server_name) }) + // Zig: `ext(?*JSMySQLConnection).* = this.getJSConnection()`. + let js_connection = self.get_js_connection(); + // SAFETY: `raw` is a live connected `us_socket_t*`; `secure` is set to + // a live `SSL_CTX*` before TLS upgrade is requested (Zig: + // `this.#secure.?`); `server_name` is a NUL-terminated C string owned + // by `tls_config` for the connection lifetime. + let adopted = unsafe { + crate::jsc::adopt_socket_tls( + raw, + tls_group, + bun_uws::SocketKind::MysqlTls, + &mut *self + .secure + .expect("secure SSL_CTX must be set before upgradeToTLS"), + &self.tls_config, + js_connection, + |s| self.socket = s, + ) }; - // Zig: `@sizeOf(?*JSMySQLConnection)` — `?*T` is an 8-byte null-niche - // optional. The Rust layout-equivalent is `Option>`; using - // `Option<*mut T>` here would request 16 bytes (separate discriminant) - // and desync with the trampoline reader (uws_handlers.rs) which reads - // the slot as `Option>`. - let ext_size = core::mem::size_of::>>() as i32; - - // SAFETY: `raw` is a live connected `us_socket_t*`; adopt_tls may - // realloc and return a different ptr. - let Some(new_socket) = (unsafe { &mut *raw }).adopt_tls( - tls_group, - bun_uws::SocketKind::MysqlTls, - ssl_ctx, - sni, - ext_size, - ext_size, - ) else { + if !adopted { return Err(FlushQueueError::AuthenticationFailed); - }; - - let js_connection = self.get_js_connection(); - let new_socket = new_socket.as_ptr(); - // SAFETY: `new_socket` is a live us_socket_t freshly returned by - // `adopt_tls`; ext storage was sized for - // `Option>` above. One `&mut` reborrow - // drives both safe inherent methods (`ext` / `start_tls_handshake`). - // Zig: `ext(?*JSMySQLConnection).* = this.getJSConnection()`. - let sock = unsafe { &mut *new_socket }; - *sock.ext::>>() = - core::ptr::NonNull::new(js_connection); - self.socket = Socket::SocketTls(uws::SocketTLS { - socket: uws::InternalSocket::Connected(new_socket), - }); - // ext is now repointed; safe to kick the handshake (any dispatch lands here). - sock.start_tls_handshake(); + } Ok(()) } @@ -422,50 +396,20 @@ impl MySQLConnection { self.sequence_id = self.sequence_id.wrapping_add(1); if handshake_success { self.tls_status = TLSStatus::SslOk; - if self.tls_config.reject_unauthorized() != 0 { - // follow the same rules as postgres - // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 - // only reject the connection if reject_unauthorized == true - match self.ssl_mode { - SSLMode::VerifyCa | SSLMode::VerifyFull => { - if ssl_error.error_no != 0 { - self.tls_status = TLSStatus::SslFailed; - return Ok(false); - } - - // VerifyFull additionally requires the certificate identity to - // match the intended host. Absence of a configured server name is - // not a license to skip the check — fail closed. - if self.ssl_mode == SSLMode::VerifyFull { - let servername = self.tls_config.server_name(); - if servername.is_null() { - self.tls_status = TLSStatus::SslFailed; - return Ok(false); - } - // SAFETY: native handle of a connected TLS socket is `SSL*`. - let ssl_ptr: *mut bun_boringssl_sys::SSL = self - .socket - .get_native_handle() - .map(|h| h.cast()) - .unwrap_or(core::ptr::null_mut()); - // SAFETY: `server_name` is a NUL-terminated C string owned by - // `tls_config` for the connection lifetime. - let hostname = unsafe { bun_core::ffi::cstr(servername) }.to_bytes(); - if ssl_ptr.is_null() - || !bun_boringssl::check_server_identity( - // SAFETY: `ssl_ptr` is non-null (checked by the short-circuit above) and live (handshake just succeeded). - unsafe { &mut *ssl_ptr }, - hostname, - ) - { - self.tls_status = TLSStatus::SslFailed; - return Ok(false); - } - } - } - // require is the same as prefer - SSLMode::Require | SSLMode::Prefer | SSLMode::Disable => {} - } + // follow the same rules as postgres + // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 + // only reject the connection if reject_unauthorized == true (require is the same as prefer) + if self.tls_config.reject_unauthorized() != 0 + && matches!(self.ssl_mode, SSLMode::VerifyCa | SSLMode::VerifyFull) + && !connection_args::verify_tls_server( + self.ssl_mode == SSLMode::VerifyFull, + &self.tls_config, + self.socket.get_native_handle(), + ssl_error.error_no, + ) + { + self.tls_status = TLSStatus::SslFailed; + return Ok(false); } self.send_handshake_response()?; return Ok(true); diff --git a/src/sql_jsc/mysql/MySQLContext.rs b/src/sql_jsc/mysql/MySQLContext.rs index 47fc381a231..5afd23c1d48 100644 --- a/src/sql_jsc/mysql/MySQLContext.rs +++ b/src/sql_jsc/mysql/MySQLContext.rs @@ -1,4 +1,6 @@ -use crate::jsc::{CallFrame, JSGlobalObject, JSValue, StrongOptional, VirtualMachineSqlExt as _}; +use crate::jsc::{ + CallFrame, JSGlobalObject, JSValue, JsResult, StrongOptional, VirtualMachineSqlExt as _, +}; #[repr(C)] #[derive(Default)] @@ -10,14 +12,14 @@ pub struct MySQLContext { // TODO(port): bun_jsc::host_fn proc-macro // (Zig: `@export(&JSC.toJSHostFn(init), .{ .name = "MySQLContext__init" })`). -pub(crate) fn init(global: &JSGlobalObject, frame: &CallFrame) -> JSValue { +pub(crate) fn init(global: &JSGlobalObject, frame: &CallFrame) -> JsResult { // `bun_vm()` → `&'static VirtualMachine` (per-thread singleton); `as_mut()` // is the canonical safe escape hatch for the shrinking set of `&mut self` // helpers like `sql_state()` — one audited unsafe lives in bun_jsc. let ctx = &mut global.bun_vm().as_mut().sql_state().mysql_context; ctx.on_query_resolve_fn.set(global, frame.argument(0)); ctx.on_query_reject_fn.set(global, frame.argument(1)); - JSValue::UNDEFINED + Ok(JSValue::UNDEFINED) } // ported from: src/sql_jsc/mysql/MySQLContext.zig diff --git a/src/sql_jsc/mysql/MySQLQuery.rs b/src/sql_jsc/mysql/MySQLQuery.rs index fadf5f72af2..42c38dfa70a 100644 --- a/src/sql_jsc/mysql/MySQLQuery.rs +++ b/src/sql_jsc/mysql/MySQLQuery.rs @@ -40,66 +40,14 @@ pub struct MySQLQuery { } /// Zig: `packed struct(u8) { bigint, simple, pipelined: bool, result_mode: SQLQueryResultMode, _padding: u3 }` -/// Not all fields are `bool`, so per PORTING.md this is a transparent `u8` with shift accessors. -#[repr(transparent)] -#[derive(Copy, Clone, Default)] -struct Flags(u8); - -impl Flags { - const BIGINT: u8 = 1 << 0; - const SIMPLE: u8 = 1 << 1; - const PIPELINED: u8 = 1 << 2; - const RESULT_MODE_SHIFT: u8 = 3; - const RESULT_MODE_MASK: u8 = 0b11 << Self::RESULT_MODE_SHIFT; // SQLQueryResultMode is 2 bits (3 bool + 2 + 3 pad = 8) - - #[inline] - fn bigint(self) -> bool { - self.0 & Self::BIGINT != 0 - } - #[inline] - fn simple(self) -> bool { - self.0 & Self::SIMPLE != 0 - } - #[inline] - fn pipelined(self) -> bool { - self.0 & Self::PIPELINED != 0 - } - #[inline] - fn set_pipelined(&mut self, v: bool) { - if v { - self.0 |= Self::PIPELINED; - } else { - self.0 &= !Self::PIPELINED; - } - } - #[inline] - fn result_mode(self) -> SQLQueryResultMode { - // result_mode bits were written from a valid SQLQueryResultMode - // discriminant (`set_result_mode`); the unreachable 4th bit-state - // traps (matches Zig's safety-checked `@enumFromInt`). - match (self.0 & Self::RESULT_MODE_MASK) >> Self::RESULT_MODE_SHIFT { - 0 => SQLQueryResultMode::Objects, - 1 => SQLQueryResultMode::Values, - 2 => SQLQueryResultMode::Raw, - n => unreachable!("invalid SQLQueryResultMode {n}"), - } - } - #[inline] - fn set_result_mode(&mut self, m: SQLQueryResultMode) { - self.0 = (self.0 & !Self::RESULT_MODE_MASK) | ((m as u8) << Self::RESULT_MODE_SHIFT); - } - #[inline] - fn new(bigint: bool, simple: bool) -> Self { - let mut f = 0u8; - if bigint { - f |= Self::BIGINT; - } - if simple { - f |= Self::SIMPLE; - } - // result_mode default = .objects (assumed discriminant 0) - Self(f) - } +/// Ported as a plain struct (like `PostgresSQLQuery::Flags`) — the packing is not +/// load-bearing on the Rust side. +#[derive(Copy, Clone)] +struct Flags { + bigint: bool, + simple: bool, + pipelined: bool, + result_mode: SQLQueryResultMode, } impl MySQLQuery { @@ -415,13 +363,13 @@ impl MySQLQuery { if !global_object.has_exception() { let _ = global_object.throw_value(mysql_error_to_js( global_object, - Some(b"failed to bind and execute query"), + b"failed to bind and execute query", err, )); } return Err(bun_core::err!("JSError")); } - self.flags.set_pipelined(true); + self.flags.pipelined = true; } } my_sql_statement::Status::Parsing => { @@ -459,7 +407,12 @@ impl MySQLQuery { statement: core::ptr::null_mut(), query, status: Status::Pending, - flags: Flags::new(bigint, simple), + flags: Flags { + bigint, + simple, + pipelined: false, + result_mode: SQLQueryResultMode::Objects, + }, } } @@ -471,7 +424,7 @@ impl MySQLQuery { binding_value: JSValue, ) -> Result<(), bun_core::Error> { // TODO(port): narrow error set - if self.flags.simple() { + if self.flags.simple { debug!("runSimpleQuery"); return self.run_simple_query(connection); } @@ -494,7 +447,7 @@ impl MySQLQuery { #[inline] pub fn set_result_mode(&mut self, result_mode: SQLQueryResultMode) { - self.flags.set_result_mode(result_mode); + self.flags.result_mode = result_mode; } #[inline] @@ -563,22 +516,22 @@ impl MySQLQuery { #[inline] pub fn is_pipelined(&self) -> bool { - self.flags.pipelined() + self.flags.pipelined } #[inline] pub fn is_simple(&self) -> bool { - self.flags.simple() + self.flags.simple } #[inline] pub fn is_bigint_supported(&self) -> bool { - self.flags.bigint() + self.flags.bigint } #[inline] pub fn get_result_mode(&self) -> SQLQueryResultMode { - self.flags.result_mode() + self.flags.result_mode } #[inline] diff --git a/src/sql_jsc/mysql/MySQLStatement.rs b/src/sql_jsc/mysql/MySQLStatement.rs index eee302ae2b6..cbc168b107d 100644 --- a/src/sql_jsc/mysql/MySQLStatement.rs +++ b/src/sql_jsc/mysql/MySQLStatement.rs @@ -1,7 +1,6 @@ use core::cell::Cell; use crate::jsc::{JSGlobalObject, JSValue}; -use bun_collections::StringHashMap; use crate::mysql::protocol::Signature; use crate::shared::CachedStructure; @@ -9,7 +8,6 @@ use crate::shared::sql_data_cell::Flags as DataCellFlags; use bun_sql::mysql::protocol::column_definition41::ColumnDefinition41; use bun_sql::mysql::protocol::error_packet::ErrorPacket; -use bun_sql::shared::ColumnIdentifier; pub use bun_sql::mysql::mysql_param::Param; @@ -122,50 +120,9 @@ impl MySQLStatement { self.execution_flags .remove(ExecutionFlags::NEEDS_DUPLICATE_CHECK); - let mut seen_numbers: Vec = Vec::new(); - let mut seen_fields: StringHashMap<()> = StringHashMap::default(); - seen_fields.reserve(self.columns.len()); - - // iterate backwards - let mut remaining = self.columns.len(); - let mut flags = DataCellFlags::default(); - while remaining > 0 { - remaining -= 1; - let field: &mut ColumnDefinition41 = &mut self.columns[remaining]; - match &field.name_or_index { - ColumnIdentifier::Name(name) => { - // PORT NOTE: reshaped for borrowck — compute `found_existing` before - // mutating `field.name_or_index`. - let found_existing = seen_fields - .get_or_put(name.slice()) - .expect("OOM") - .found_existing; - if found_existing { - // Zig: field.name_or_index.deinit(); — Drop on assignment handles this. - field.name_or_index = ColumnIdentifier::Duplicate; - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } - - flags.insert(DataCellFlags::HAS_NAMED_COLUMNS); - } - ColumnIdentifier::Index(index) => { - let index = *index; - if seen_numbers.contains(&index) { - field.name_or_index = ColumnIdentifier::Duplicate; - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } else { - seen_numbers.push(index); - } - - flags.insert(DataCellFlags::HAS_INDEXED_COLUMNS); - } - ColumnIdentifier::Duplicate => { - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } - } - } - - self.fields_flags = flags; + self.fields_flags = crate::shared::cached_structure::mark_duplicate_columns( + self.columns.iter_mut().map(|c| &mut c.name_or_index), + ); } // PORT NOTE: Zig returns `CachedStructure` by value (struct copy). Returning `&CachedStructure` diff --git a/src/sql_jsc/mysql/MySQLValue.rs b/src/sql_jsc/mysql/MySQLValue.rs index 3192a5eaffb..d81f2a4bb99 100644 --- a/src/sql_jsc/mysql/MySQLValue.rs +++ b/src/sql_jsc/mysql/MySQLValue.rs @@ -3,8 +3,8 @@ //! `CharacterSet`/`FieldType` enums without `JSValue` references. use crate::jsc::{ - IntegerRange, JSGlobalObject, JSGlobalObjectSqlExt as _, JSType, JSValue, JsError, JsResult, - MarkedArgumentBuffer, StringJsc as _, bun_string_jsc, js_error_to_mysql, + ErrorCode, IntegerRange, JSGlobalObject, JSType, JSValue, JsError, JsResult, + MarkedArgumentBuffer, StringJsc as _, js_error_to_mysql, }; use bun_core::zig_string::Slice as ZigStringSlice; use bun_core::{OwnedString, String as BunString}; @@ -47,11 +47,14 @@ pub(crate) fn field_type_from_js( return Ok(FieldType::MYSQL_TYPE_LONGLONG); } return Err(global_object - .err_out_of_range(format_args!( - "The value is out of range. It must be >= {} and <= {}.", - i64::MIN, - u64::MAX - )) + .err( + ErrorCode::OUT_OF_RANGE, + format_args!( + "The value is out of range. It must be >= {} and <= {}.", + i64::MIN, + u64::MAX + ), + ) .throw()); } @@ -137,7 +140,6 @@ pub enum Value { BytesData(Data), Date(DateTime), Time(Time), - // Decimal(Decimal), } /// BLOB parameter bytes. `MySQLQuery.bind()` fills every `Value` before @@ -184,6 +186,38 @@ impl Drop for Bytes { // Value's Zig `deinit` only forwarded to payload deinit; Rust auto-drops enum // payloads (ZigStringSlice, Bytes, Data all impl Drop), so no explicit Drop. +/// The integer branches of `Value::from_js` validate against the full range of +/// the target type, so the bounds are derived from `T` rather than repeated at +/// every call site. +fn int_range(field_name: &'static [u8]) -> IntegerRange { + IntegerRange { + min: T::MIN_I128, + max: T::MAX_I128, + field_name, + ..Default::default() + } +} + +fn validate_int( + global_object: &JSGlobalObject, + value: JSValue, + field_name: &'static [u8], +) -> Result { + global_object + .validate_integer_range::(value, T::ZERO, int_range::(field_name)) + .map_err(js_error_to_mysql) +} + +fn validate_bigint( + global_object: &JSGlobalObject, + value: JSValue, + field_name: &'static [u8], +) -> Result { + global_object + .validate_big_int_range::(value, T::ZERO, int_range::(field_name)) + .map_err(js_error_to_mysql) +} + impl Value { pub fn to_data(&self, field_type: FieldType) -> Result { let mut buffer = [0u8; 15]; // Large enough for all fixed-size types @@ -232,7 +266,6 @@ impl Value { Value::Time(d) => { pos = d.to_binary(field_type, &mut buffer) as usize; } - // Value::Decimal(dec) => return dec.to_binary(field_type), Value::StringData(data) | Value::BytesData(data) => { // TODO(port): Zig returned `data` by value (copy of Data union); // `bun_sql::shared::Data` is not `Clone` in the Rust port, so @@ -280,99 +313,24 @@ impl Value { FieldType::MYSQL_TYPE_TINY => Ok(Value::Bool(value.to_boolean())), FieldType::MYSQL_TYPE_SHORT => { if unsigned { - return Ok(Value::Ushort( - global_object - .validate_integer_range::( - value, - 0, - IntegerRange { - min: u16::MIN as i128, - max: u16::MAX as i128, - field_name: b"u16", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )); + Ok(Value::Ushort(validate_int(global_object, value, b"u16")?)) + } else { + Ok(Value::Short(validate_int(global_object, value, b"i16")?)) } - Ok(Value::Short( - global_object - .validate_integer_range::( - value, - 0, - IntegerRange { - min: i16::MIN as i128, - max: i16::MAX as i128, - field_name: b"i16", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )) } FieldType::MYSQL_TYPE_LONG => { if unsigned { - return Ok(Value::Uint( - global_object - .validate_integer_range::( - value, - 0, - IntegerRange { - min: u32::MIN as i128, - max: u32::MAX as i128, - field_name: b"u32", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )); + Ok(Value::Uint(validate_int(global_object, value, b"u32")?)) + } else { + Ok(Value::Int(validate_int(global_object, value, b"i32")?)) } - Ok(Value::Int( - global_object - .validate_integer_range::( - value, - 0, - IntegerRange { - min: i32::MIN as i128, - max: i32::MAX as i128, - field_name: b"i32", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )) } FieldType::MYSQL_TYPE_LONGLONG => { if unsigned { - return Ok(Value::Ulong( - global_object - .validate_big_int_range::( - value, - 0, - IntegerRange { - min: 0, - max: u64::MAX as i128, - field_name: b"u64", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )); + Ok(Value::Ulong(validate_bigint(global_object, value, b"u64")?)) + } else { + Ok(Value::Long(validate_bigint(global_object, value, b"i64")?)) } - Ok(Value::Long( - global_object - .validate_big_int_range::( - value, - 0, - IntegerRange { - min: i64::MIN as i128, - max: i64::MAX as i128, - field_name: b"i64", - ..Default::default() - }, - ) - .map_err(js_error_to_mysql)?, - )) } FieldType::MYSQL_TYPE_FLOAT => Ok(Value::Float( @@ -798,15 +756,6 @@ impl Time { } } - pub fn to_unix_timestamp(&self) -> i64 { - let mut total_ms: i64 = 0; - total_ms = total_ms.saturating_add((self.days as i64).saturating_mul(86400000)); - total_ms = total_ms.saturating_add((self.hours as i64).saturating_mul(3600000)); - total_ms = total_ms.saturating_add((self.minutes as i64).saturating_mul(60000)); - total_ms = total_ms.saturating_add((self.seconds as i64).saturating_mul(1000)); - total_ms - } - pub fn from_data(data: &Data) -> Result { // TODO(port): narrow error set Ok(Self::from_binary(data.slice())) @@ -875,50 +824,6 @@ impl Time { } } -pub struct Decimal { - // MySQL DECIMAL is stored as a sequence of base-10 digits - pub digits: Box<[u8]>, - pub scale: u8, - pub negative: bool, -} - -impl Decimal { - pub fn to_js(&self, global_object: &JSGlobalObject) -> JSValue { - // PERF(port): was stack-fallback (std.heap.stackFallback(64, ...)) — profile if it shows up on a hot path. - let mut str: Vec = Vec::new(); - - if self.negative { - str.push(b'-'); - } - - let decimal_pos = self.digits.len() - self.scale as usize; - for (i, digit) in self.digits.iter().enumerate() { - if i == decimal_pos && self.scale > 0 { - str.push(b'.'); - } - str.push(digit + b'0'); - } - - bun_string_jsc::create_utf8_for_js(global_object, &str).unwrap_or(JSValue::ZERO) - } - - pub fn to_binary(&self, _field_type: FieldType) -> Result { - // Zig: `bun.todoPanic(@src(), "Decimal.toBinary not implemented", .{});` - // Intentional shipped runtime "feature not yet implemented" — not a - // porting placeholder. The `Decimal` arm of `Value` is commented out, - // so this is unreachable today. - bun_core::todo_panic!("Decimal.toBinary not implemented") - } - - // pub fn from_data(data: &Data) -> Result { - // Ok(Self::from_binary(data.slice())) - // } - - // pub fn from_binary(_: &[u8]) -> Decimal { - // bun_core::todo_panic!("Decimal.fromBinary not implemented") - // } -} - // Helper functions for date calculations fn is_leap_year(year: u16) -> bool { (year.is_multiple_of(4) && !year.is_multiple_of(100)) || year.is_multiple_of(400) diff --git a/src/sql_jsc/mysql/protocol/DecodeBinaryValue.rs b/src/sql_jsc/mysql/protocol/DecodeBinaryValue.rs index fe0cf0addf2..d717a272ccf 100644 --- a/src/sql_jsc/mysql/protocol/DecodeBinaryValue.rs +++ b/src/sql_jsc/mysql/protocol/DecodeBinaryValue.rs @@ -1,7 +1,6 @@ use crate::jsc::JSGlobalObject; use crate::mysql::my_sql_value::{DateTime, Time}; use crate::shared::sql_data_cell::SQLDataCell; -use crate::shared::sql_data_cell::{Tag as CellTag, Value as CellValue}; use bun_sql::mysql::mysql_types as types; use bun_sql::mysql::mysql_types::FieldType; use bun_sql::mysql::protocol::new_reader::{NewReader, ReaderContext}; @@ -38,18 +37,9 @@ pub fn decode_binary_value( } let val = reader.byte()?; if unsigned { - return Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { uint4: val as u32 }, - ..Default::default() - }); + return Ok(SQLDataCell::uint4(val as u32)); } - let ival: i8 = val as i8; - Ok(SQLDataCell { - tag: CellTag::Int4, - value: CellValue { int4: ival as i32 }, - ..Default::default() - }) + Ok(SQLDataCell::int4(val as i8 as i32)) } FieldType::MYSQL_TYPE_SHORT => { if raw { @@ -57,21 +47,9 @@ pub fn decode_binary_value( return Ok(SQLDataCell::raw(Some(&data))); } if unsigned { - return Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { - uint4: reader.int::()? as u32, - }, - ..Default::default() - }); + return Ok(SQLDataCell::uint4(reader.int::()? as u32)); } - Ok(SQLDataCell { - tag: CellTag::Int4, - value: CellValue { - int4: reader.int::()? as i32, - }, - ..Default::default() - }) + Ok(SQLDataCell::int4(reader.int::()? as i32)) } FieldType::MYSQL_TYPE_YEAR => { // Binary protocol sends YEAR as a fixed 2-byte unsigned field; @@ -80,13 +58,7 @@ pub fn decode_binary_value( let data = reader.read(2)?; return Ok(SQLDataCell::raw(Some(&data))); } - Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { - uint4: reader.int::()? as u32, - }, - ..Default::default() - }) + Ok(SQLDataCell::uint4(reader.int::()? as u32)) } FieldType::MYSQL_TYPE_INT24 => { if raw { @@ -96,21 +68,9 @@ pub fn decode_binary_value( return Ok(SQLDataCell::raw(Some(&data.substring(0, 3)))); } if unsigned { - return Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { - uint4: reader.int_u24()?, - }, - ..Default::default() - }); + return Ok(SQLDataCell::uint4(reader.int_u24()?)); } - Ok(SQLDataCell { - tag: CellTag::Int4, - value: CellValue { - int4: reader.int_i24()?, - }, - ..Default::default() - }) + Ok(SQLDataCell::int4(reader.int_i24()?)) } FieldType::MYSQL_TYPE_LONG => { if raw { @@ -118,21 +78,9 @@ pub fn decode_binary_value( return Ok(SQLDataCell::raw(Some(&data))); } if unsigned { - return Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { - uint4: reader.int::()?, - }, - ..Default::default() - }); + return Ok(SQLDataCell::uint4(reader.int::()?)); } - Ok(SQLDataCell { - tag: CellTag::Int4, - value: CellValue { - int4: reader.int::()?, - }, - ..Default::default() - }) + Ok(SQLDataCell::int4(reader.int::()?)) } FieldType::MYSQL_TYPE_LONGLONG => { if raw { @@ -141,107 +89,45 @@ pub fn decode_binary_value( if unsigned { let val = reader.int::()?; if val <= u32::MAX as u64 { - return Ok(SQLDataCell { - tag: CellTag::Uint4, - value: CellValue { - uint4: u32::try_from(val).expect("int cast"), - }, - ..Default::default() - }); + return Ok(SQLDataCell::uint4(u32::try_from(val).expect("int cast"))); } if bigint { - return Ok(SQLDataCell { - tag: CellTag::Uint8, - value: CellValue { uint8: val }, - ..Default::default() - }); + return Ok(SQLDataCell::uint8(val)); } let mut buffer = bun_core::fmt::ItoaBuf::new(); let slice = bun_core::fmt::itoa(&mut buffer, val); - return Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }); + return Ok(SQLDataCell::string(slice)); } let val = reader.int::()?; if val >= i32::MIN as i64 && val <= i32::MAX as i64 { - return Ok(SQLDataCell { - tag: CellTag::Int4, - value: CellValue { - int4: i32::try_from(val).expect("int cast"), - }, - ..Default::default() - }); + return Ok(SQLDataCell::int4(i32::try_from(val).expect("int cast"))); } if bigint { - return Ok(SQLDataCell { - tag: CellTag::Int8, - value: CellValue { int8: val }, - ..Default::default() - }); + return Ok(SQLDataCell::int8(val)); } let mut buffer = bun_core::fmt::ItoaBuf::new(); let slice = bun_core::fmt::itoa(&mut buffer, val); - Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(slice)) } FieldType::MYSQL_TYPE_FLOAT => { if raw { let data = reader.read(4)?; return Ok(SQLDataCell::raw(Some(&data))); } - Ok(SQLDataCell { - tag: CellTag::Float8, - value: CellValue { - float8: f32::from_bits(reader.int::()?) as f64, - }, - ..Default::default() - }) + Ok(SQLDataCell::float8( + f32::from_bits(reader.int::()?) as f64 + )) } FieldType::MYSQL_TYPE_DOUBLE => { if raw { let data = reader.read(8)?; return Ok(SQLDataCell::raw(Some(&data))); } - Ok(SQLDataCell { - tag: CellTag::Float8, - value: CellValue { - float8: f64::from_bits(reader.int::()?), - }, - ..Default::default() - }) + Ok(SQLDataCell::float8(f64::from_bits(reader.int::()?))) } FieldType::MYSQL_TYPE_TIME => { match reader.byte()? { - 0 => { - let slice = b"00:00:00"; - Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: clone_utf8_wtf_impl(slice), - }, - free_value: 1, - ..Default::default() - }) - } + 0 => Ok(SQLDataCell::string(b"00:00:00")), l @ (8 | 12) => { let data = reader.read(l as usize)?; let time = Time::from_data(&data)?; @@ -279,18 +165,7 @@ pub fn decode_binary_value( break 'brk &buffer[..32 - remaining]; }; // PORT NOTE: reshaped for borrowck — compute remaining before re-borrowing buffer - Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(slice)) } _ => Err(bun_core::err!("InvalidBinaryValue")), } @@ -301,11 +176,7 @@ pub fn decode_binary_value( // A zero-length binary DATETIME is MySQL's "0000-00-00 00:00:00" // sentinel — surface it as Invalid Date (NaN), not the Unix epoch, // so it agrees with the text path's from_text(). - 0 => Ok(SQLDataCell { - tag: CellTag::Date, - value: CellValue { date: f64::NAN }, - ..Default::default() - }), + 0 => Ok(SQLDataCell::date(f64::NAN)), l @ (11 | 7 | 4) => { let data = reader.read(l as usize)?; let time = DateTime::from_data(&data)?; @@ -316,11 +187,7 @@ pub fn decode_binary_value( bun_jsc::JsError::Terminated => bun_core::err!("Terminated"), bun_jsc::JsError::Thrown => bun_core::err!("Thrown"), })?; - Ok(SQLDataCell { - tag: CellTag::Date, - value: CellValue { date: ts }, - ..Default::default() - }) + Ok(SQLDataCell::date(ts)) } _ => Err(bun_core::err!("InvalidBinaryValue")), }, @@ -335,19 +202,7 @@ pub fn decode_binary_value( return Ok(SQLDataCell::raw(Some(&data))); } let string_data = reader.encode_len_string()?; - let slice = string_data.slice(); - Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(string_data.slice())) } // When the column contains a binary string we return a Buffer otherwise a string @@ -373,19 +228,7 @@ pub fn decode_binary_value( if binary && character_set == BINARY_CHARSET { return Ok(SQLDataCell::raw(Some(&string_data))); } - let slice = string_data.slice(); - Ok(SQLDataCell { - tag: CellTag::String, - value: CellValue { - string: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(string_data.slice())) } FieldType::MYSQL_TYPE_JSON => { @@ -394,36 +237,14 @@ pub fn decode_binary_value( return Ok(SQLDataCell::raw(Some(&data))); } let string_data = reader.encode_len_string()?; - let slice = string_data.slice(); - Ok(SQLDataCell { - tag: CellTag::Json, - value: CellValue { - json: if !slice.is_empty() { - clone_utf8_wtf_impl(slice) - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::json(string_data.slice())) } FieldType::MYSQL_TYPE_BIT => { // BIT(1) is a special case, it's a boolean if column_length == 1 { let data = reader.encode_len_string()?; let slice = data.slice(); - Ok(SQLDataCell { - tag: CellTag::Bool, - value: CellValue { - bool_: if !slice.is_empty() && slice[0] == 1 { - 1 - } else { - 0 - }, - }, - ..Default::default() - }) + Ok(SQLDataCell::bool_(!slice.is_empty() && slice[0] == 1)) } else { let data = reader.encode_len_string()?; Ok(SQLDataCell::raw(Some(&data))) @@ -436,11 +257,4 @@ pub fn decode_binary_value( } } -// Zig accesses `bun.String.cloneUTF8(slice).value.WTFStringImpl` directly (union field); -// `leak_wtf_impl()` is the Rust equivalent — transfers the +1 ref to the cell (`free_value = 1`). -#[inline] -fn clone_utf8_wtf_impl(slice: &[u8]) -> bun_core::WTFStringImpl { - bun_core::String::clone_utf8(slice).leak_wtf_impl() -} - // ported from: src/sql_jsc/mysql/protocol/DecodeBinaryValue.zig diff --git a/src/sql_jsc/mysql/protocol/ResultSet.rs b/src/sql_jsc/mysql/protocol/ResultSet.rs index 6fcb599ceae..249f4481d1d 100644 --- a/src/sql_jsc/mysql/protocol/ResultSet.rs +++ b/src/sql_jsc/mysql/protocol/ResultSet.rs @@ -1,8 +1,5 @@ -use core::ptr; - -use crate::jsc::{ExternColumnIdentifier, JSGlobalObject, JSValue}; +use crate::jsc::{JSGlobalObject, JSValue}; use crate::mysql::my_sql_value::DateTime; -use bun_core::String as BunString; use bun_core::parse_int; use bun_sql::mysql::protocol::ColumnDefinition41; @@ -15,7 +12,7 @@ use bun_sql::shared::Data; use bun_sql::shared::SQLQueryResultMode; use crate::shared::CachedStructure; -use crate::shared::sql_data_cell::{Flags as SQLDataCellFlags, SQLDataCell, Tag, Value}; +use crate::shared::sql_data_cell::{Flags as SQLDataCellFlags, SQLDataCell}; use super::decode_binary_value::{self, decode_binary_value}; @@ -44,25 +41,16 @@ impl<'a> Row<'a> { // PORT NOTE: Zig `?CachedStructure` is by-value; passed by ref here because CachedStructure is non-Copy (owns Strong + Box). cached_structure: Option<&CachedStructure>, ) -> crate::jsc::JsResult { - let mut names: *mut ExternColumnIdentifier = ptr::null_mut(); - let mut names_count: u32 = 0; - if let Some(c) = cached_structure { - if let Some(f) = c.fields.as_deref() { - names = f.as_ptr().cast_mut(); - names_count = f.len() as u32; - } - } - - SQLDataCell::construct_object_from_data_cell( + let count = self.values.len() as u32; + SQLDataCell::to_js_object( global_object, array, structure, - self.values.as_mut_ptr(), - self.values.len() as u32, + &mut self.values, + count, flags, result_mode as u8, - names, - names_count, + cached_structure, ) } @@ -93,143 +81,71 @@ impl<'a> Row<'a> { match column.column_type { MYSQL_TYPE_FLOAT | MYSQL_TYPE_DOUBLE => { let val: f64 = bun_core::parse_double(value.slice()).unwrap_or(f64::NAN); - *cell = SQLDataCell { - tag: Tag::Float8, - value: Value { float8: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::float8(val); } // YEAR arrives as a bare ASCII integer in the text protocol; parse it // like SHORT so `.simple()` returns the same JS number as the binary path. MYSQL_TYPE_TINY | MYSQL_TYPE_SHORT | MYSQL_TYPE_YEAR => { if column.flags.contains(ColumnFlags::UNSIGNED) { let val: u16 = parse_int::(value.slice(), 10).unwrap_or(0); - *cell = SQLDataCell { - tag: Tag::Uint4, - value: Value { uint4: val as u32 }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::uint4(val as u32); } else { let val: i16 = parse_int::(value.slice(), 10).unwrap_or(0); - *cell = SQLDataCell { - tag: Tag::Int4, - value: Value { int4: val as i32 }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::int4(val as i32); } } MYSQL_TYPE_LONG => { if column.flags.contains(ColumnFlags::UNSIGNED) { let val: u32 = parse_int::(value.slice(), 10).unwrap_or(0); - *cell = SQLDataCell { - tag: Tag::Uint4, - value: Value { uint4: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::uint4(val); } else { let val: i32 = parse_int::(value.slice(), 10).unwrap_or(i32::MIN); - *cell = SQLDataCell { - tag: Tag::Int4, - value: Value { int4: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::int4(val); } } MYSQL_TYPE_INT24 => { if column.flags.contains(ColumnFlags::UNSIGNED) { let val: u32 = parse_int::(value.slice(), 10).unwrap_or(0); - *cell = SQLDataCell { - tag: Tag::Uint4, - value: Value { uint4: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::uint4(val); } else { // std.math.minInt(i24) == -8_388_608 let val: i32 = parse_int::(value.slice(), 10).unwrap_or(-8_388_608); - *cell = SQLDataCell { - tag: Tag::Int4, - value: Value { int4: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::int4(val); } } MYSQL_TYPE_LONGLONG => { if column.flags.contains(ColumnFlags::UNSIGNED) { let val: u64 = parse_int::(value.slice(), 10).unwrap_or(0); if val <= u32::MAX as u64 { - *cell = SQLDataCell { - tag: Tag::Uint4, - value: Value { - uint4: u32::try_from(val).expect("int cast"), - }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::uint4(u32::try_from(val).expect("int cast")); return; } if self.bigint { - *cell = SQLDataCell { - tag: Tag::Uint8, - value: Value { uint8: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::uint8(val); return; } } else { let val: i64 = parse_int::(value.slice(), 10).unwrap_or(0); if val >= i32::MIN as i64 && val <= i32::MAX as i64 { - *cell = SQLDataCell { - tag: Tag::Int4, - value: Value { - int4: i32::try_from(val).expect("int cast"), - }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::int4(i32::try_from(val).expect("int cast")); return; } if self.bigint { - *cell = SQLDataCell { - tag: Tag::Int8, - value: Value { int8: val }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::int8(val); return; } } - let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::String, - value: Value { - string: clone_wtf_string_or_null(slice), - }, - free_value: 1, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::string(value.slice()); } MYSQL_TYPE_JSON => { - let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::Json, - value: Value { - json: clone_wtf_string_or_null(slice), - }, - free_value: 1, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::json(value.slice()); } MYSQL_TYPE_TIME => { // lets handle TIME special case as string // -838:59:50 to 838:59:59 is valid - let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::String, - value: Value { - string: clone_wtf_string_or_null(slice), - }, - free_value: 1, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::string(value.slice()); } MYSQL_TYPE_DATE | MYSQL_TYPE_DATETIME | MYSQL_TYPE_TIMESTAMP => { // MySQL's DATE/DATETIME/TIMESTAMP text has no timezone, so parse @@ -242,44 +158,22 @@ impl<'a> Row<'a> { Some(dt) => dt.to_js_timestamp(self.global_object).unwrap_or(f64::NAN), None => f64::NAN, }; - *cell = SQLDataCell { - tag: Tag::Date, - value: Value { date }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::date(date); } // NEWDECIMAL is always sent as an ASCII decimal string regardless of the // column's BINARY flag / charset. Computed decimals (SUM/AVG/arithmetic/CAST) // carry the BINARY flag and charset 63, so the catch-all arm's binary-charset // heuristic would wrongly return them as a Buffer. MYSQL_TYPE_NEWDECIMAL => { - let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::String, - value: Value { - string: clone_wtf_string_or_null(slice), - }, - free_value: 1, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::string(value.slice()); } MYSQL_TYPE_BIT => { // BIT(1) is a special case, it's a boolean if column.column_length == 1 { let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::Bool, - value: Value { - bool_: if !slice.is_empty() && slice[0] == 1 { - 1 - } else { - 0 - }, - }, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::bool_(!slice.is_empty() && slice[0] == 1); } else { - *cell = SQLDataCell::raw(value); + *cell = SQLDataCell::raw(Some(value)); } } _ => { @@ -290,17 +184,9 @@ impl<'a> Row<'a> { if column.flags.contains(ColumnFlags::BINARY) && column.character_set == decode_binary_value::BINARY_CHARSET { - *cell = SQLDataCell::raw(value); + *cell = SQLDataCell::raw(Some(value)); } else { - let slice = value.slice(); - *cell = SQLDataCell { - tag: Tag::String, - value: Value { - string: clone_wtf_string_or_null(slice), - }, - free_value: 1, - ..SQLDataCell::default() - }; + *cell = SQLDataCell::string(value.slice()); } } } @@ -310,15 +196,7 @@ impl<'a> Row<'a> { &mut self, reader: NewReader, ) -> Result<(), AnyMySQLError> { - let cells = vec![ - SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..SQLDataCell::default() - }; - self.columns.len() - ] - .into_boxed_slice(); + let cells = vec![SQLDataCell::null(); self.columns.len()].into_boxed_slice(); let mut cells = scopeguard::guard(cells, |mut cells| { for value in cells.iter_mut() { value.deinit(); @@ -336,15 +214,11 @@ impl<'a> Row<'a> { // NULL value reader.skip(result.bytes_read); // this dont matter if is raw because we will sent as null too like in postgres - *value = SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..SQLDataCell::default() - }; + *value = SQLDataCell::null(); } else { if self.raw { let data = reader.encode_len_string()?; - *value = SQLDataCell::raw(&data); + *value = SQLDataCell::raw(Some(&data)); } else { reader.skip(result.bytes_read); let string_data = @@ -382,15 +256,7 @@ impl<'a> Row<'a> { let bitmap_bytes = (self.columns.len() + 7 + 2) / 8; let null_bitmap = reader.read(bitmap_bytes)?; - let cells = vec![ - SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..SQLDataCell::default() - }; - self.columns.len() - ] - .into_boxed_slice(); + let cells = vec![SQLDataCell::null(); self.columns.len()].into_boxed_slice(); let mut cells = scopeguard::guard(cells, |mut cells| { for value in cells.iter_mut() { value.deinit(); @@ -405,11 +271,7 @@ impl<'a> Row<'a> { let is_null = (null_bitmap.slice()[byte_pos] & (1u8 << bit_pos)) != 0; if is_null { - *value = SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..SQLDataCell::default() - }; + *value = SQLDataCell::null(); continue; } @@ -460,18 +322,4 @@ impl<'a> Drop for Row<'a> { } } -// ─── helpers ────────────────────────────────────────────────────────────── - -#[inline] -fn clone_wtf_string_or_null(slice: &[u8]) -> bun_core::WTFStringImpl { - // Zig: `bun.String.cloneUTF8(slice).value.WTFStringImpl` — extracts the raw - // WTFStringImpl* from a freshly-cloned bun.String (ownership transferred to the cell, - // freed via `free_value = 1`). - if !slice.is_empty() { - BunString::clone_utf8(slice).leak_wtf_impl() - } else { - ptr::null_mut() - } -} - // ported from: src/sql_jsc/mysql/protocol/ResultSet.zig diff --git a/src/sql_jsc/mysql/protocol/Signature.rs b/src/sql_jsc/mysql/protocol/Signature.rs index 4a201300e0e..5bacf263f06 100644 --- a/src/sql_jsc/mysql/protocol/Signature.rs +++ b/src/sql_jsc/mysql/protocol/Signature.rs @@ -22,32 +22,6 @@ impl Signature { // `deinit` deleted — body only freed owned slices; `Box<[T]>` fields drop automatically. - pub fn hash(&self) -> u64 { - // Hash `name` followed by each param's `(type, flags)` field-by-field. - // - // This intentionally does NOT reinterpret `&[Param]` as `&[u8]` (the Zig - // port originally mirrored `std.mem.sliceAsBytes`): `Param` has default - // `repr(Rust)` with a `u8` enum + `u16` bitflags, leaving one padding - // byte. Exposing padding through `&[u8]` reads uninitialized memory and - // is UB. The hash is a process-local prepared-statement cache key, so it - // only needs to be self-consistent — byte-for-byte layout parity with Zig - // is not required. - // - // PERF(port): Zig fed two slices into a streaming Wyhash; bun_wyhash - // currently lacks the std-compatible streaming `Wyhash` type. Concatenate - // into a temp Vec until that lands. - // TODO(port): bun_wyhash::Wyhash (streaming std-compatible API) - const BYTES_PER_PARAM: usize = 1 /* FieldType */ + 2 /* ColumnFlags */; - let mut buf: Vec = - Vec::with_capacity(self.name.len() + self.fields.len() * BYTES_PER_PARAM); - buf.extend_from_slice(&self.name); - for p in self.fields.iter() { - buf.push(p.r#type as u8); - buf.extend_from_slice(&p.flags.to_int().to_ne_bytes()); - } - bun_wyhash::hash(&buf) - } - // TODO(port): narrow error set (mixes JS errors, alloc, and InvalidQueryBinding) pub fn generate( global_object: &JSGlobalObject, diff --git a/src/sql_jsc/mysql/protocol/any_mysql_error_jsc.rs b/src/sql_jsc/mysql/protocol/any_mysql_error_jsc.rs index 0b9743668b6..b00a963b278 100644 --- a/src/sql_jsc/mysql/protocol/any_mysql_error_jsc.rs +++ b/src/sql_jsc/mysql/protocol/any_mysql_error_jsc.rs @@ -26,64 +26,12 @@ impl IntoAnyMySQLError for bun_core::Error { } } -/// Zig `?[]const u8`. Callers pass either a bare byte-ish value (`&str`, -/// `&[u8]`, `&[u8; N]`, `&Vec`) or the same wrapped in `Option<_>`, so -/// this trait — rather than `AsRef<[u8]>` directly — lets one signature -/// accept both shapes without touching every callsite. -pub(crate) trait MaybeBytes { - fn as_maybe_bytes(&self) -> Option<&[u8]>; -} -impl MaybeBytes for str { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - Some(self.as_bytes()) - } -} -impl MaybeBytes for [u8] { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - Some(self) - } -} -impl MaybeBytes for [u8; N] { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - Some(self.as_slice()) - } -} -impl MaybeBytes for Vec { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - Some(self.as_slice()) - } -} -impl MaybeBytes for String { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - Some(self.as_bytes()) - } -} -impl MaybeBytes for &T { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - (**self).as_maybe_bytes() - } -} -impl MaybeBytes for Option { - #[inline] - fn as_maybe_bytes(&self) -> Option<&[u8]> { - self.as_ref().and_then(|b| b.as_maybe_bytes()) - } -} - pub(crate) fn mysql_error_to_js( global_object: &JSGlobalObject, - // Zig: `?[]const u8` — `message orelse @errorName(err)`. - message: impl MaybeBytes, + message: &[u8], err: impl IntoAnyMySQLError, ) -> JSValue { let name = err.mysql_error_name(); - let msg: &[u8] = message.as_maybe_bytes().unwrap_or(name.as_bytes()); let code: &'static [u8] = match name { "ConnectionClosed" => b"ERR_MYSQL_CONNECTION_CLOSED", @@ -134,7 +82,7 @@ pub(crate) fn mysql_error_to_js( create_mysql_error( global_object, - msg, + message, &MySQLErrorOptions { code, errno: None, diff --git a/src/sql_jsc/postgres.rs b/src/sql_jsc/postgres.rs index 6b8dcdc4d09..a845f15f8de 100644 --- a/src/sql_jsc/postgres.rs +++ b/src/sql_jsc/postgres.rs @@ -64,21 +64,9 @@ pub mod postgres_request; pub mod data_cell; pub mod types { - #[path = "bool.rs"] - pub mod r#bool; - - #[path = "bytea.rs"] - pub mod bytea; - #[path = "date.rs"] pub mod date; - #[path = "json.rs"] - pub mod json; - - #[path = "PostgresString.rs"] - pub mod postgres_string; - #[path = "tag_jsc.rs"] pub mod tag_jsc; } diff --git a/src/sql_jsc/postgres/DataCell.rs b/src/sql_jsc/postgres/DataCell.rs index 71ab83e6706..b7b010bbffa 100644 --- a/src/sql_jsc/postgres/DataCell.rs +++ b/src/sql_jsc/postgres/DataCell.rs @@ -134,11 +134,7 @@ fn parse_array( return Ok(SQLDataCell { tag: Tag::Array, value: Value { - array: Array { - ptr: core::ptr::null_mut(), - len: 0, - cap: 0, - }, + array: Array::default(), }, ..Default::default() }); @@ -225,14 +221,10 @@ fn parse_array( let date_str = &slice[1..current_idx]; let mut str = BunString::init(date_str); // defer str.deref() → Drop on BunString - array.push(SQLDataCell { - tag: Tag::Date, - value: Value { - date: crate::jsc::bun_string_jsc::parse_date(&mut str, global_object) - .map_err(crate::jsc::js_error_to_postgres)?, - }, - ..Default::default() - }); + array.push(SQLDataCell::date( + crate::jsc::bun_string_jsc::parse_date(&mut str, global_object) + .map_err(crate::jsc::js_error_to_postgres)?, + )); slice = try_slice(slice, current_idx + 1); continue; @@ -249,18 +241,7 @@ fn parse_array( }; let unescaped = unescape_postgres_string(str_bytes, buffer) .map_err(|_| AnyPostgresError::InvalidByteSequence)?; - array.push(SQLDataCell { - tag: Tag::Json, - value: Value { - json: if !unescaped.is_empty() { - BunString::clone_utf8(unescaped).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::json(unescaped)); slice = try_slice(slice, current_idx + 1); continue; } @@ -269,14 +250,7 @@ fn parse_array( let str_bytes = &slice[1..current_idx]; if str_bytes.is_empty() { // empty string - array.push(SQLDataCell { - tag: Tag::String, - value: Value { - string: core::ptr::null_mut(), - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::string(b"")); slice = try_slice(slice, current_idx + 1); continue; } @@ -290,18 +264,7 @@ fn parse_array( }; let string_bytes = unescape_postgres_string(str_bytes, buffer) .map_err(|_| AnyPostgresError::InvalidByteSequence)?; - array.push(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !string_bytes.is_empty() { - BunString::clone_utf8(string_bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::string(string_bytes)); slice = try_slice(slice, current_idx + 1); continue; @@ -357,45 +320,22 @@ fn parse_array( let element = &slice[0..current_idx]; // lets handle NULL case here, if is a string "NULL" it will have quotes, if its a NULL it will be just NULL if element == b"NULL" { - array.push(SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - }); + array.push(SQLDataCell::null()); slice = try_slice(slice, current_idx); continue; } if array_type == types::Tag::date_array { let mut str = BunString::init(element); - array.push(SQLDataCell { - tag: Tag::Date, - value: Value { date: crate::jsc::bun_string_jsc::parse_date(&mut str, global_object).map_err(crate::jsc::js_error_to_postgres)? }, - ..Default::default() - }); + array.push(SQLDataCell::date( + crate::jsc::bun_string_jsc::parse_date(&mut str, global_object) + .map_err(crate::jsc::js_error_to_postgres)?, + )); } else { // the only escape sequency possible here is \b if element == b"\\b" { - array.push(SQLDataCell { - tag: Tag::String, - value: Value { - string: BunString::clone_utf8(b"\x08").leak_wtf_impl(), - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::string(b"\x08")); } else { - array.push(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !element.is_empty() { - BunString::clone_utf8(element).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::string(element)); } } slice = try_slice(slice, current_idx); @@ -411,21 +351,13 @@ fn parse_array( } if slice.len() >= 4 { if &slice[0..4] == b"NULL" { - array.push(SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - }); + array.push(SQLDataCell::null()); slice = try_slice(slice, 4); continue; } } if &slice[0..3] == b"NaN" { - array.push(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: f64::NAN }, - ..Default::default() - }); + array.push(SQLDataCell::float8(f64::NAN)); slice = try_slice(slice, 3); continue; } @@ -438,21 +370,13 @@ fn parse_array( return Err(AnyPostgresError::UnsupportedArrayFormat); } if &slice[0..5] == b"false" { - array.push(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: 0 }, - ..Default::default() - }); + array.push(SQLDataCell::bool_(false)); slice = try_slice(slice, 5); continue; } return Err(AnyPostgresError::UnsupportedArrayFormat); } else { - array.push(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: 0 }, - ..Default::default() - }); + array.push(SQLDataCell::bool_(false)); slice = try_slice(slice, 1); continue; } @@ -464,21 +388,13 @@ fn parse_array( return Err(AnyPostgresError::UnsupportedArrayFormat); } if &slice[0..4] == b"true" { - array.push(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: 1 }, - ..Default::default() - }); + array.push(SQLDataCell::bool_(true)); slice = try_slice(slice, 4); continue; } return Err(AnyPostgresError::UnsupportedArrayFormat); } else { - array.push(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: 1 }, - ..Default::default() - }); + array.push(SQLDataCell::bool_(true)); slice = try_slice(slice, 1); continue; } @@ -490,17 +406,9 @@ fn parse_array( array_type, types::Tag::date_array | types::Tag::timestamp_array | types::Tag::timestamptz_array ) { - array.push(SQLDataCell { - tag: Tag::Date, - value: Value { date: f64::INFINITY }, - ..Default::default() - }); + array.push(SQLDataCell::date(f64::INFINITY)); } else { - array.push(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: f64::INFINITY }, - ..Default::default() - }); + array.push(SQLDataCell::float8(f64::INFINITY)); } slice = try_slice(slice, 8); continue; @@ -587,17 +495,9 @@ fn parse_array( | types::Tag::timestamp_array | types::Tag::timestamptz_array ) { - array.push(SQLDataCell { - tag: Tag::Date, - value: Value { date: val }, - ..Default::default() - }); + array.push(SQLDataCell::date(val)); } else { - array.push(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: val }, - ..Default::default() - }); + array.push(SQLDataCell::float8(val)); } advance_after = Some(8 + (is_negative as usize)); break; @@ -621,67 +521,38 @@ fn parse_array( } let element = &slice[0..current_idx]; if is_float || array_type == types::Tag::float8_array { - array.push(SQLDataCell { - tag: Tag::Float8, - value: Value { - float8: bun_core::parse_double(element).unwrap_or(f64::NAN), - }, - ..Default::default() - }); + array.push(SQLDataCell::float8( + bun_core::parse_double(element).unwrap_or(f64::NAN), + )); slice = try_slice(slice, current_idx); continue; } match array_type { types::Tag::int8_array => { if bigint { - array.push(SQLDataCell { - tag: Tag::Int8, - value: Value { - int8: bun_core::fmt::parse_decimal::(element) - .ok_or(AnyPostgresError::UnsupportedArrayFormat)?, - }, - ..Default::default() - }); + array.push(SQLDataCell::int8( + bun_core::fmt::parse_decimal::(element) + .ok_or(AnyPostgresError::UnsupportedArrayFormat)?, + )); } else { - array.push(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !element.is_empty() { - BunString::clone_utf8(element).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }); + array.push(SQLDataCell::string(element)); } slice = try_slice(slice, current_idx); continue; } types::Tag::cid_array | types::Tag::xid_array | types::Tag::oid_array => { - array.push(SQLDataCell { - tag: Tag::Uint4, - value: Value { - uint4: bun_core::fmt::parse_decimal::(element).unwrap_or(0), - }, - ..Default::default() - }); + array.push(SQLDataCell::uint4( + bun_core::fmt::parse_decimal::(element).unwrap_or(0), + )); slice = try_slice(slice, current_idx); continue; } _ => { + // @intCast(value) — i32 → i32, identity here let value = bun_core::fmt::parse_decimal::(element) .ok_or(AnyPostgresError::UnsupportedArrayFormat)?; - array.push(SQLDataCell { - tag: Tag::Int4, - value: Value { - // @intCast(value) — i32 → i32, identity here - int4: value, - }, - ..Default::default() - }); + array.push(SQLDataCell::int4(value)); slice = try_slice(slice, current_idx); continue; } @@ -879,103 +750,58 @@ pub(crate) fn from_bytes( } T::int2 => { if binary { - Ok(SQLDataCell { - tag: Tag::Int4, - value: Value { int4: parse_binary_int2(bytes)? as i32 }, - ..Default::default() - }) + Ok(SQLDataCell::int4(parse_binary_int2(bytes)? as i32)) } else { - Ok(SQLDataCell { - tag: Tag::Int4, - value: Value { int4: bun_core::fmt::parse_decimal::(bytes).unwrap_or(0) }, - ..Default::default() - }) + Ok(SQLDataCell::int4( + bun_core::fmt::parse_decimal::(bytes).unwrap_or(0), + )) } } T::cid | T::xid | T::oid => { if binary { - Ok(SQLDataCell { - tag: Tag::Uint4, - value: Value { uint4: parse_binary_oid(bytes)? }, - ..Default::default() - }) + Ok(SQLDataCell::uint4(parse_binary_oid(bytes)?)) } else { - Ok(SQLDataCell { - tag: Tag::Uint4, - value: Value { uint4: bun_core::fmt::parse_decimal::(bytes).unwrap_or(0) }, - ..Default::default() - }) + Ok(SQLDataCell::uint4( + bun_core::fmt::parse_decimal::(bytes).unwrap_or(0), + )) } } T::int4 => { if binary { - Ok(SQLDataCell { - tag: Tag::Int4, - value: Value { int4: parse_binary_int4(bytes)? }, - ..Default::default() - }) + Ok(SQLDataCell::int4(parse_binary_int4(bytes)?)) } else { - Ok(SQLDataCell { - tag: Tag::Int4, - value: Value { int4: bun_core::fmt::parse_decimal::(bytes).unwrap_or(0) }, - ..Default::default() - }) + Ok(SQLDataCell::int4( + bun_core::fmt::parse_decimal::(bytes).unwrap_or(0), + )) } } // postgres when reading bigint as int8 it returns a string unless type: { bigint: postgres.BigInt is set T::int8 => { if bigint { // .int8 is a 64-bit integer always string - Ok(SQLDataCell { - tag: Tag::Int8, - value: Value { int8: bun_core::fmt::parse_decimal::(bytes).unwrap_or(0) }, - ..Default::default() - }) + Ok(SQLDataCell::int8( + bun_core::fmt::parse_decimal::(bytes).unwrap_or(0), + )) } else { - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !bytes.is_empty() { - BunString::clone_utf8(bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(bytes)) } } T::float8 => { if binary && bytes.len() == 8 { - Ok(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: parse_binary_float8(bytes)? }, - ..Default::default() - }) + Ok(SQLDataCell::float8(parse_binary_float8(bytes)?)) } else { - let float8: f64 = bun_core::parse_double(bytes).unwrap_or(f64::NAN); - Ok(SQLDataCell { - tag: Tag::Float8, - value: Value { float8 }, - ..Default::default() - }) + Ok(SQLDataCell::float8( + bun_core::parse_double(bytes).unwrap_or(f64::NAN), + )) } } T::float4 => { if binary && bytes.len() == 4 { - Ok(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: parse_binary_float4(bytes)? as f64 }, - ..Default::default() - }) + Ok(SQLDataCell::float8(parse_binary_float4(bytes)? as f64)) } else { - let float4: f64 = bun_core::parse_double(bytes).unwrap_or(f64::NAN); - Ok(SQLDataCell { - tag: Tag::Float8, - value: Value { float8: float4 }, - ..Default::default() - }) + Ok(SQLDataCell::float8( + bun_core::parse_double(bytes).unwrap_or(f64::NAN), + )) } } T::numeric => { @@ -987,86 +813,37 @@ pub(crate) fn from_bytes( // if is binary format lets display as a string because JS cant handle it in a safe way let result = parse_binary_numeric(bytes, &mut numeric_buffer) .map_err(|_| AnyPostgresError::UnsupportedNumericFormat)?; - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: BunString::clone_utf8(result.slice()).leak_wtf_impl(), - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(result.slice())) } else { // nice text is actually what we want here - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !bytes.is_empty() { - BunString::clone_utf8(bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(bytes)) } } - T::jsonb | T::json => Ok(SQLDataCell { - tag: Tag::Json, - value: Value { - json: if !bytes.is_empty() { - BunString::clone_utf8(bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }), + T::jsonb | T::json => Ok(SQLDataCell::json(bytes)), T::bool => { if binary { - Ok(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: (!bytes.is_empty() && bytes[0] == 1) as u8 }, - ..Default::default() - }) + Ok(SQLDataCell::bool_(!bytes.is_empty() && bytes[0] == 1)) } else { - Ok(SQLDataCell { - tag: Tag::Bool, - value: Value { bool_: (!bytes.is_empty() && bytes[0] == b't') as u8 }, - ..Default::default() - }) + Ok(SQLDataCell::bool_(!bytes.is_empty() && bytes[0] == b't')) } } tag @ (T::date | T::timestamp | T::timestamptz) => { if bytes.is_empty() { - return Ok(SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - }); + return Ok(SQLDataCell::null()); } if binary && bytes.len() == 8 { match tag { - T::timestamptz => Ok(SQLDataCell { - tag: Tag::DateWithTimeZone, - value: Value { date_with_time_zone: crate::postgres::types::date::from_binary(bytes) }, - ..Default::default() - }), - T::timestamp => Ok(SQLDataCell { - tag: Tag::Date, - value: Value { date: crate::postgres::types::date::from_binary(bytes) }, - ..Default::default() - }), + T::timestamptz => Ok(SQLDataCell::date_with_tz( + crate::postgres::types::date::from_binary(bytes), + )), + T::timestamp => Ok(SQLDataCell::date( + crate::postgres::types::date::from_binary(bytes), + )), _ => unreachable!(), } } else { if bun_core::strings::eql_case_insensitive_ascii(bytes, b"NULL", true) { - return Ok(SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - }); + return Ok(SQLDataCell::null()); } // `timestamp` (without time zone) text carries no offset, so // decode its components as UTC to match the binary path. `date` @@ -1084,20 +861,12 @@ pub(crate) fn from_bytes( .map_err(crate::jsc::js_error_to_postgres)? } }; - Ok(SQLDataCell { - tag: Tag::Date, - value: Value { date }, - ..Default::default() - }) + Ok(SQLDataCell::date(date)) } } tag @ (T::time | T::timetz) => { if bytes.is_empty() { - return Ok(SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - }); + return Ok(SQLDataCell::null()); } if binary { if tag == T::time && bytes.len() == 8 { @@ -1108,14 +877,7 @@ pub(crate) fn from_bytes( let mut buffer = [0u8; 32]; let len = Postgres__formatTime(microseconds, &mut buffer, 32); - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: BunString::clone_utf8(&buffer[0..len]).leak_wtf_impl(), - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(&buffer[0..len])) } else if tag == T::timetz && bytes.len() == 12 { // PostgreSQL sends timetz as microseconds since midnight (8 bytes) + timezone offset in seconds (4 bytes) let microseconds = i64::from_ne_bytes(bytes[0..8].try_into().expect("infallible: size matches")).swap_bytes(); @@ -1125,31 +887,13 @@ pub(crate) fn from_bytes( let mut buffer = [0u8; 48]; let len = Postgres__formatTimeTz(microseconds, tz_offset_seconds, &mut buffer, 48); - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: BunString::clone_utf8(&buffer[0..len]).leak_wtf_impl(), - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(&buffer[0..len])) } else { Err(AnyPostgresError::InvalidBinaryData) } } else { // Text format - just return as string - Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !bytes.is_empty() { - BunString::clone_utf8(bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }) + Ok(SQLDataCell::string(bytes)) } } @@ -1216,18 +960,7 @@ pub(crate) fn from_bytes( | T::timestamp_array | T::timestamptz_array | T::interval_array) => parse_array(bytes, bigint, tag, global_object, None, false, 0), - _ => Ok(SQLDataCell { - tag: Tag::String, - value: Value { - string: if !bytes.is_empty() { - BunString::clone_utf8(bytes).leak_wtf_impl() - } else { - core::ptr::null_mut() - }, - }, - free_value: 1, - ..Default::default() - }), + _ => Ok(SQLDataCell::string(bytes)), } } @@ -1480,25 +1213,15 @@ impl<'a> Putter<'a> { result_mode: PostgresSQLQueryResultMode, cached_structure: Option<&PostgresCachedStructure>, ) -> Result { - let mut names: *mut crate::jsc::ExternColumnIdentifier = core::ptr::null_mut(); - let mut names_count: u32 = 0; - if let Some(c) = cached_structure { - if let Some(f) = c.fields.as_ref() { - names = f.as_ptr().cast_mut(); - names_count = f.len() as u32; - } - } - - SQLDataCell::construct_object_from_data_cell( + SQLDataCell::to_js_object( global_object, array, structure, - self.list.as_mut_ptr(), + self.list, self.fields.len() as u32, flags, result_mode as u8, - names, - names_count, + cached_structure, ) .map_err(crate::jsc::js_error_to_postgres) } @@ -1533,7 +1256,7 @@ impl<'a> Putter<'a> { bun_core::scoped_log!(Postgres, "index: {}, oid: {}", index, oid); let cell: &mut SQLDataCell = &mut self.list[index as usize]; if IS_RAW { - *cell = SQLDataCell::raw(optional_bytes); + *cell = SQLDataCell::raw(optional_bytes.as_deref()); } else { let tag = if (types::short::MAX as u32) < oid { types::Tag::text @@ -1551,11 +1274,7 @@ impl<'a> Putter<'a> { self.global_object, )? } else { - SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - } + SQLDataCell::null() }; } self.count += 1; diff --git a/src/sql_jsc/postgres/PostgresRequest.rs b/src/sql_jsc/postgres/PostgresRequest.rs index f62e966b4c1..bb3cd6aa507 100644 --- a/src/sql_jsc/postgres/PostgresRequest.rs +++ b/src/sql_jsc/postgres/PostgresRequest.rs @@ -286,11 +286,29 @@ pub fn write_query( Ok(()) } +/// Writes the Execute + Flush + Sync tail shared by every prepared-statement +/// query path. +#[inline] +fn execute_flush_sync( + name: &[u8], + mut writer: protocol::NewWriter, +) -> Result<(), AnyPostgresError> { + let exec = protocol::Execute { + p: protocol::PortalOrPreparedStatement::PreparedStatement(name), + ..Default::default() + }; + exec.write_internal(&mut writer)?; + + writer.write(&protocol::FLUSH)?; + writer.write(&protocol::SYNC)?; + Ok(()) +} + pub(crate) fn prepare_and_query_with_signature( global: &JSGlobalObject, query: &[u8], array_value: JSValue, - mut writer: protocol::NewWriter, + writer: protocol::NewWriter, signature: &mut Signature, ) -> Result<(), AnyPostgresError> { write_query( @@ -309,17 +327,7 @@ pub(crate) fn prepare_and_query_with_signature( &[], writer, )?; - let exec = protocol::Execute { - p: protocol::PortalOrPreparedStatement::PreparedStatement( - &signature.prepared_statement_name, - ), - ..Default::default() - }; - exec.write_internal(&mut writer)?; - - writer.write(&protocol::FLUSH)?; - writer.write(&protocol::SYNC)?; - Ok(()) + execute_flush_sync(&signature.prepared_statement_name, writer) } pub(crate) fn bind_and_execute( @@ -327,7 +335,7 @@ pub(crate) fn bind_and_execute( statement: &PostgresSQLStatement, array_value: JSValue, columns_value: JSValue, - mut writer: protocol::NewWriter, + writer: protocol::NewWriter, ) -> Result<(), AnyPostgresError> { write_bind( &statement.signature.prepared_statement_name, @@ -339,17 +347,7 @@ pub(crate) fn bind_and_execute( &statement.fields, writer, )?; - let exec = protocol::Execute { - p: protocol::PortalOrPreparedStatement::PreparedStatement( - &statement.signature.prepared_statement_name, - ), - ..Default::default() - }; - exec.write_internal(&mut writer)?; - - writer.write(&protocol::FLUSH)?; - writer.write(&protocol::SYNC)?; - Ok(()) + execute_flush_sync(&statement.signature.prepared_statement_name, writer) } /// Atomically sends Parse + [Describe] + Bind + Execute + Flush + Sync as a single message batch. @@ -368,8 +366,10 @@ pub fn parse_and_bind_and_execute( ) -> Result<(), AnyPostgresError> { let name = &statement.signature.prepared_statement_name; - // Parse - { + if include_describe { + // Parse + Describe (needed on first execution to learn parameter/result types for caching) + write_query(query, name, &statement.signature.fields, writer)?; + } else { let q = protocol::Parse { name, params: &statement.signature.fields, @@ -379,15 +379,6 @@ pub fn parse_and_bind_and_execute( bun_core::scoped_log!(Postgres, "Parse: {}", bun_fmt::quote(query)); } - // Describe (needed on first execution to learn parameter/result types for caching) - if include_describe { - let d = protocol::Describe { - p: protocol::PortalOrPreparedStatement::PreparedStatement(name), - }; - d.write_internal(writer)?; - bun_core::scoped_log!(Postgres, "Describe: {}", bun_fmt::quote(name)); - } - // Bind — use server-provided types if available (binary format), otherwise // fall back to signature types (text format for unknowns). The server will // handle text-to-type conversion based on the parameter types from Parse. @@ -409,16 +400,7 @@ pub fn parse_and_bind_and_execute( writer, )?; - // Execute - let exec = protocol::Execute { - p: protocol::PortalOrPreparedStatement::PreparedStatement(name), - ..Default::default() - }; - exec.write_internal(&mut writer)?; - - writer.write(&protocol::FLUSH)?; - writer.write(&protocol::SYNC)?; - Ok(()) + execute_flush_sync(name, writer) } pub(crate) fn execute_query( diff --git a/src/sql_jsc/postgres/PostgresSQLConnection.rs b/src/sql_jsc/postgres/PostgresSQLConnection.rs index 87cab2555b1..a65d6e35938 100644 --- a/src/sql_jsc/postgres/PostgresSQLConnection.rs +++ b/src/sql_jsc/postgres/PostgresSQLConnection.rs @@ -7,8 +7,8 @@ use core::sync::atomic::{AtomicU32, Ordering}; use crate::jsc::EventLoopTimer; use crate::jsc::webcore::AutoFlusher; use crate::jsc::{ - self as jsc, CallFrame, HasAutoFlush, JSGlobalObject, JSGlobalObjectSqlExt as _, JSValue, - JsResult, VirtualMachine, VirtualMachineSqlExt as _, + self as jsc, CallFrame, HasAutoFlush, JSGlobalObject, JSValue, JsResult, VirtualMachine, + VirtualMachineSqlExt as _, }; use bun_boringssl as BoringSSL; use bun_collections::{HashMap, OffsetByteList, StringMap}; @@ -31,6 +31,7 @@ use crate::postgres::postgres_sql_query::{self, Status as QueryStatus}; use crate::postgres::postgres_sql_statement::{Error as StatementError, Status as StatementStatus}; use crate::postgres::sasl::SASLStatus; use crate::shared::CachedStructure as PostgresCachedStructure; +use crate::shared::connection_args; use bun_sql::postgres::AnyPostgresError; use bun_sql::postgres::PostgresErrorOptions; use bun_sql::postgres::PostgresProtocol as protocol; @@ -183,7 +184,7 @@ impl PostgresSQLConnection { /// Read-modify-write the packed `Cell` through `&self`. #[inline] - fn update_flags(&self, f: impl FnOnce(&mut ConnectionFlags)) { + pub(crate) fn update_flags(&self, f: impl FnOnce(&mut ConnectionFlags)) { let mut v = self.flags.get(); f(&mut v); self.flags.set(v); @@ -427,73 +428,37 @@ impl PostgresSQLConnection { // Zig: `this.socket.SocketTCP.socket.connected` — at this point we are // a plain TCP socket in the Connected state. - let Socket::SocketTcp(tcp) = self.socket.get() else { - self.fail( - b"Failed to upgrade to TLS", - AnyPostgresError::TLSUpgradeFailed, - ); - return; - }; - let uws::InternalSocket::Connected(raw) = tcp.socket else { - self.fail( - b"Failed to upgrade to TLS", - AnyPostgresError::TLSUpgradeFailed, - ); - return; - }; - - // SAFETY: `secure` is set to a live `SSL_CTX*` before `setup_tls` is - // reached (Zig: `this.secure.?`). - let ssl_ctx = unsafe { - &mut *self - .secure - .expect("secure SSL_CTX must be set before setupTLS") - }; - let server_name = self.tls_config.server_name(); - let sni = if server_name.is_null() { - None - } else { - // SAFETY: `server_name` is a NUL-terminated C string owned by - // `tls_config` for the connection lifetime. - Some(unsafe { bun_core::ffi::cstr(server_name) }) + let raw = match self.socket.get() { + Socket::SocketTcp(tcp) => tcp.socket.get(), + _ => None, }; - // Zig: `@sizeOf(?*PostgresSQLConnection)` — `?*T` is an 8-byte null-niche - // optional. The Rust layout-equivalent is `Option>`; using - // `Option<*mut T>` here would request 16 bytes (separate discriminant) - // and desync with the trampoline reader (uws_handlers.rs) which reads - // the slot as `Option>`. - let ext_size = - core::mem::size_of::>>() as i32; - - // SAFETY: `raw` is a live connected `us_socket_t*`; adopt_tls may - // realloc and return a different ptr. - let Some(new_socket) = (unsafe { &mut *raw }).adopt_tls( - tls_group, - bun_uws::SocketKind::PostgresTls, - ssl_ctx, - sni, - ext_size, - ext_size, - ) else { + // Zig: `ext(?*PostgresSQLConnection).* = this`. + let adopted = raw.is_some_and(|raw| { + // SAFETY: `raw` is a live connected `us_socket_t*`; `secure` is set + // to a live `SSL_CTX*` before `setup_tls` is reached (Zig: + // `this.secure.?`); `server_name` is a NUL-terminated C string + // owned by `tls_config` for the connection lifetime. + unsafe { + jsc::adopt_socket_tls( + raw, + tls_group, + bun_uws::SocketKind::PostgresTls, + &mut *self + .secure + .expect("secure SSL_CTX must be set before setupTLS"), + &self.tls_config, + self.as_ctx_ptr(), + |s| self.socket.set(s), + ) + } + }); + if !adopted { self.fail( b"Failed to upgrade to TLS", AnyPostgresError::TLSUpgradeFailed, ); return; - }; - let new_socket = new_socket.as_ptr(); - // SAFETY: `new_socket` is a live us_socket_t freshly returned by - // `adopt_tls`; ext slot is sized for `Option>` - // above. One `&mut` reborrow drives both safe inherent methods - // (`ext` / `start_tls_handshake`). Zig: `ext(?*PostgresSQLConnection).* = this`. - let sock = unsafe { &mut *new_socket }; - *sock.ext::>>() = - core::ptr::NonNull::new(self.as_ctx_ptr()); - self.socket.set(Socket::SocketTls(uws::SocketTLS { - socket: uws::InternalSocket::Connected(new_socket), - })); - // ext is now repointed; safe to kick the handshake (any dispatch lands here). - sock.start_tls_handshake(); + } self.start(); } @@ -845,57 +810,20 @@ impl PostgresSQLConnection { pub fn on_handshake(&self, success: i32, ssl_error: uws::us_bun_verify_error_t) { debug!("onHandshake: {} {}", success, ssl_error.error_no); - let handshake_success = success == 1; - if handshake_success { - if self.tls_config.reject_unauthorized() != 0 { - // only reject the connection if reject_unauthorized == true - match self.ssl_mode { - // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 - SSLMode::VerifyCa | SSLMode::VerifyFull => { - if ssl_error.error_no != 0 { - let Ok(v) = verify_error_to_js(&ssl_error, self.global()) else { - return; - }; - self.fail_with_js_value(v); - return; - } - - if self.ssl_mode == SSLMode::VerifyFull { - let servername = self.tls_config.server_name(); - let ok = if servername.is_null() { - false - } else { - // SAFETY: native handle of a connected TLS socket is `SSL*`. - let ssl_ptr: *mut BoringSSL::c::SSL = self - .socket - .get() - .get_native_handle() - .map_or(core::ptr::null_mut(), |p| p.cast()); - // SAFETY: `servername` is a NUL-terminated C string owned by `tls_config`. - let hostname = - unsafe { bun_core::ffi::cstr(servername) }.to_bytes(); - // SAFETY: `ssl_ptr` is the live SSL* of a connected TLS socket. - !ssl_ptr.is_null() - && BoringSSL::check_server_identity( - unsafe { &mut *ssl_ptr }, - hostname, - ) - }; - if !ok { - let Ok(v) = verify_error_to_js(&ssl_error, self.global()) else { - return; - }; - self.fail_with_js_value(v); - } - } - } - // require is the same as prefer - SSLMode::Require | SSLMode::Prefer | SSLMode::Disable => {} - } - } - } else { - // if we are here is because server rejected us, and the error_no is the cause of this - // no matter if reject_unauthorized is false because we are disconnected by the server + // https://github.com/porsager/postgres/blob/6ec85a432b17661ccacbdf7f765c651e88969d36/src/connection.js#L272-L279 + // only reject the connection if reject_unauthorized == true (require is the same as prefer); + // but if `success != 1` the server rejected us no matter what reject_unauthorized says. + let verify = self.tls_config.reject_unauthorized() != 0 + && matches!(self.ssl_mode, SSLMode::VerifyCa | SSLMode::VerifyFull); + let ok = success == 1 + && (!verify + || connection_args::verify_tls_server( + self.ssl_mode == SSLMode::VerifyFull, + &self.tls_config, + self.socket.get().get_native_handle(), + ssl_error.error_no, + )); + if !ok { let Ok(v) = verify_error_to_js(&ssl_error, self.global()) else { return; }; @@ -1082,80 +1010,16 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR // is the canonical safe escape hatch (one audited unsafe in bun_jsc) for // `&mut self` helpers like `ssl_ctx_cache()` / `postgres_socket_group()`. let vm = global_object.bun_vm().as_mut(); - let arguments = callframe.arguments(); - let hostname_str = bun_core::OwnedString::new(arguments[0].to_bun_string(global_object)?); - let port = arguments[1].coerce::(global_object)?; - - let username_str = bun_core::OwnedString::new(arguments[2].to_bun_string(global_object)?); - let password_str = bun_core::OwnedString::new(arguments[3].to_bun_string(global_object)?); - let database_str = bun_core::OwnedString::new(arguments[4].to_bun_string(global_object)?); - let ssl_mode: SSLMode = match arguments[5].to_int32() { - 0 => SSLMode::Disable, - 1 => SSLMode::Prefer, - 2 => SSLMode::Require, - 3 => SSLMode::VerifyCa, - 4 => SSLMode::VerifyFull, - _ => SSLMode::Disable, + // Args 0..=14 (hostname/port/credentials/sslMode/tls/options/path/ + // callbacks/timeouts) are decoded by the shared helper, which also builds + // the TLS `SSL_CTX` and returns it inside the `args.tls` errdefer guard. + // Ownership of `(secure, tls_config)` passes into `ptr.*` once allocated — + // `into_inner` recovers them just before the Box is built so the + // connect-fail path's `ptr.deinit()` is the sole cleanup. + let Some(args) = connection_args::parse::(vm, global_object, callframe)? else { + return Ok(JSValue::ZERO); }; - let tls_object = arguments[6]; - - let mut tls_config: jsc::api::ServerConfig::SSLConfig = Default::default(); - let mut secure: Option<*mut uws::SslCtx> = None; - if ssl_mode != SSLMode::Disable { - tls_config = if tls_object.is_boolean() && tls_object.to_boolean() { - Default::default() - } else if tls_object.is_object() { - match jsc::api::ServerConfig::SSLConfig::from_js(&mut *vm, global_object, tls_object) { - Ok(opt) => opt.unwrap_or_default(), - Err(_) => return Ok(JSValue::ZERO), - } - } else { - return Err(global_object - .throw_invalid_arguments(format_args!("tls must be a boolean or an object"))); - }; - - if global_object.has_exception() { - drop(tls_config); - return Ok(JSValue::ZERO); - } - - // We always request the cert so we can verify it and also we manually - // abort the connection if the hostname doesn't match. Built here (not - // at STARTTLS time) so cert/CA errors throw synchronously. Goes - // through the per-VM weak `SSLContextCache` so every connection in the - // pool — and every reconnect — shares one `SSL_CTX*` per distinct - // config instead of building a fresh one per `PostgresSQLConnection`. - let mut err: uws::create_bun_socket_error_t = uws::create_bun_socket_error_t::none; - secure = vm - .ssl_ctx_cache() - .get_or_create_opts(&tls_config.as_usockets_for_client_verification(), &mut err); - if secure.is_none() { - drop(tls_config); - // TODO(port): Zig `err.toJS(globalObject)` — `to_js` lives as an extension - // in the runtime _jsc crate and isn't reachable from sql_jsc; throw the - // static message instead. - return Err(global_object.throw(format_args!( - "{}", - bun_core::fmt::s(err.message().unwrap_or(b"Failed to create SSL context")) - ))); - } - } - // Covers `try arguments[7/8].toBunString()` and the null-byte rejection - // below. Ownership passes into `ptr.*` once allocated — `into_inner` - // recovers them just before the Box is built so the connect-fail path's - // `ptr.deinit()` is the sole cleanup. - // PORT NOTE: guard owns `(secure, tls_config)` by value. Do NOT - // `drop_in_place` a stack local that Rust would also auto-drop on unwind — - // that double-frees. The closure's `_tls_config` is dropped exactly once by - // normal scope-exit drop here. - let errdefer_guard = scopeguard::guard((secure, tls_config), |(secure, _tls_config)| { - if let Some(s) = secure { - // SAFETY: SSL_CTX_free is safe to call on a valid SSL_CTX*. - unsafe { BoringSSL::c::SSL_CTX_free(s) }; - } - }); - // PORT NOTE: `StringBuilder::append` takes `&mut self` and returns a borrow // of the backing buffer, so successive appends can't keep their `&[u8]` // results live across each other. The buffer is allocated once and never @@ -1168,77 +1032,33 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR let options: bun_ptr::RawSlice; let path: bun_ptr::RawSlice; - let options_str = bun_core::OwnedString::new(arguments[7].to_bun_string(global_object)?); - - let path_str = bun_core::OwnedString::new(arguments[8].to_bun_string(global_object)?); - let options_buf: Box<[u8]> = 'brk: { let mut b = bun_core::StringBuilder::default(); - b.cap += username_str.utf8_byte_length() + b.cap += args.username.slice().len() + 1 - + password_str.utf8_byte_length() + + args.password.slice().len() + 1 - + database_str.utf8_byte_length() + + args.database.slice().len() + 1 - + options_str.utf8_byte_length() + + args.options.slice().len() + 1 - + path_str.utf8_byte_length() + + args.path.slice().len() + 1; let _ = b.allocate(); - let u = username_str.to_utf8_without_ref(); - username = bun_ptr::RawSlice::new(b.append(u.slice())); - drop(u); - - let p = password_str.to_utf8_without_ref(); - password = bun_ptr::RawSlice::new(b.append(p.slice())); - drop(p); - - let d = database_str.to_utf8_without_ref(); - database = bun_ptr::RawSlice::new(b.append(d.slice())); - drop(d); - - let o = options_str.to_utf8_without_ref(); - options = bun_ptr::RawSlice::new(b.append(o.slice())); - drop(o); - - let _path = path_str.to_utf8_without_ref(); - path = bun_ptr::RawSlice::new(b.append(_path.slice())); - drop(_path); + username = bun_ptr::RawSlice::new(b.append(args.username.slice())); + password = bun_ptr::RawSlice::new(b.append(args.password.slice())); + database = bun_ptr::RawSlice::new(b.append(args.database.slice())); + options = bun_ptr::RawSlice::new(b.append(args.options.slice())); + path = bun_ptr::RawSlice::new(b.append(args.path.slice())); break 'brk b.move_to_slice(); }; - // Reject null bytes in connection parameters to prevent Postgres startup - // message parameter injection (null bytes act as field terminators in the - // wire protocol's key\0value\0 format). - for (entry, name) in [ - (username, &b"username"[..]), - (password, b"password"), - (database, b"database"), - (path, b"path"), - ] { - let entry = entry.slice(); - if !entry.is_empty() && entry.contains(&0) { - drop(options_buf); - // tls_config / secure released by the errdefer above. - return global_object.throw_invalid_arguments_fmt(format_args!( - "{} must not contain null bytes", - bstr::BStr::new(name) - )); - } - } - - let on_connect = arguments[9]; - let on_close = arguments[10]; - let idle_timeout = arguments[11].to_int32(); - let connection_timeout = arguments[12].to_int32(); - let max_lifetime = arguments[13].to_int32(); - let use_unnamed_prepared_statements = arguments[14].as_boolean(); - + let ssl_mode = args.ssl_mode; // Ownership transferred into `ptr`; disarm the errdefer and recover the // moved `secure`/`tls_config` for the struct literal below. - let (secure, tls_config) = scopeguard::ScopeGuard::into_inner(errdefer_guard); + let (secure, tls_config) = scopeguard::ScopeGuard::into_inner(args.tls); let ptr: *mut PostgresSQLConnection = bun_core::heap::into_raw(Box::new(PostgresSQLConnection { @@ -1281,9 +1101,9 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR TLSStatus::None }), ssl_mode, - idle_timeout_interval_ms: u32::try_from(idle_timeout).expect("int cast"), - connection_timeout_ms: u32::try_from(connection_timeout).expect("int cast"), - flags: Cell::new(if use_unnamed_prepared_statements { + idle_timeout_interval_ms: u32::try_from(args.idle_timeout).expect("int cast"), + connection_timeout_ms: u32::try_from(args.connection_timeout).expect("int cast"), + flags: Cell::new(if args.use_unnamed_prepared_statements { ConnectionFlags::USE_UNNAMED_PREPARED_STATEMENTS } else { ConnectionFlags::empty() @@ -1291,7 +1111,7 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR timer: JsCell::new(EventLoopTimer::init_paused( EventLoopTimerTag::PostgresSQLConnectionTimeout, )), - max_lifetime_interval_ms: u32::try_from(max_lifetime).expect("int cast"), + max_lifetime_interval_ms: u32::try_from(args.max_lifetime).expect("int cast"), max_lifetime_timer: JsCell::new(EventLoopTimer::init_paused( EventLoopTimerTag::PostgresSQLConnectionMaxLifetime, )), @@ -1304,8 +1124,6 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR let this = ParentRef::from(core::ptr::NonNull::new(ptr).expect("heap::into_raw non-null")); { - let hostname = hostname_str.to_utf8(); - // Postgres always opens plain TCP first (SSLRequest happens in-band), // so even `ssl_mode != .disable` lands in the TCP group; `setupTLS()` // adopts into `postgres_tls_group` after the server's `S`. @@ -1325,8 +1143,8 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR group, uws::SocketKind::Postgres, None, - hostname.slice(), - port, + args.hostname.to_utf8().slice(), + args.port, ptr, false, ) @@ -1350,8 +1168,8 @@ pub(crate) fn call(global_object: &JSGlobalObject, callframe: &CallFrame) -> JsR let js_value = js::to_js(ptr, global_object); js_value.ensure_still_alive(); this.js_value.set(crate::jsc::JsRef::init_weak(js_value)); - js::onconnect_set_cached(js_value, global_object, on_connect); - js::onclose_set_cached(js_value, global_object, on_close); + js::onconnect_set_cached(js_value, global_object, args.on_connect); + js::onclose_set_cached(js_value, global_object, args.on_close); /* TODO(port): bun_core::analytics::Features::POSTGRES_CONNECTIONS counter */ Ok(js_value) } @@ -1675,13 +1493,18 @@ pub struct Writer { pub connection: BackRef, } -impl Writer { - // `write_buffer` is a `JsCell`; route mutation through the safe - // closure-scoped `with_mut` and reads through `get()` so the backref - // deref stays inside `BackRef`'s safe `Deref` — no raw `get_mut` - // escape hatch needed. +// `write_buffer` is a `JsCell`; route mutation through the safe +// closure-scoped `with_mut` and reads through `get()` so the backref +// deref stays inside `BackRef`'s safe `Deref` — no raw `get_mut` +// escape hatch needed. +impl protocol::WriterContext for Writer { + #[inline] + fn offset(self) -> usize { + self.connection.write_buffer.get().len() as usize + } - pub fn write(&mut self, data: &[u8]) -> Result<(), AnyPostgresError> { + #[inline] + fn write(self, data: &[u8]) -> Result<(), AnyPostgresError> { self.connection .write_buffer .with_mut(|b| b.write(data)) @@ -1689,31 +1512,13 @@ impl Writer { Ok(()) } - pub fn pwrite(&mut self, data: &[u8], index: usize) -> Result<(), AnyPostgresError> { + #[inline] + fn pwrite(self, data: &[u8], index: usize) -> Result<(), AnyPostgresError> { self.connection.write_buffer.with_mut(|b| { b.byte_list.slice_mut()[index..][..data.len()].copy_from_slice(data); }); Ok(()) } - - pub fn offset(self) -> usize { - self.connection.write_buffer.get().len() as usize - } -} - -impl protocol::WriterContext for Writer { - #[inline] - fn offset(self) -> usize { - Writer::offset(self) - } - #[inline] - fn write(mut self, bytes: &[u8]) -> Result<(), AnyPostgresError> { - Writer::write(&mut self, bytes) - } - #[inline] - fn pwrite(mut self, bytes: &[u8], i: usize) -> Result<(), AnyPostgresError> { - Writer::pwrite(&mut self, bytes, i) - } } impl PostgresSQLConnection { @@ -1745,31 +1550,38 @@ impl Reader { self.connection.read_buffer.get() } - pub(crate) fn mark_message_start(&mut self) { - let head = self.read_buffer().head; - self.connection.last_message_start.set(head); + fn ensure_capacity(self, count: usize) -> bool { + let buf = self.read_buffer(); + (buf.head as usize) + count <= buf.byte_list.len() } +} - pub(crate) fn ensure_length(self, count: usize) -> bool { - self.ensure_capacity(count) +impl protocol::ReaderContext for Reader { + #[inline] + fn mark_message_start(&mut self) { + let head = self.read_buffer().head; + self.connection.last_message_start.set(head); } - pub(crate) fn peek(&self) -> &[u8] { + #[inline] + fn peek(&self) -> &[u8] { self.read_buffer().remaining() } - pub(crate) fn skip(&mut self, count: usize) { + #[inline] + fn skip(&mut self, count: usize) { self.connection.read_buffer.with_mut(|buf| { buf.head = (buf.head + (count as u32)).min(buf.byte_list.len() as u32); }); } - pub(crate) fn ensure_capacity(self, count: usize) -> bool { - let buf = self.read_buffer(); - (buf.head as usize) + count <= buf.byte_list.len() + #[inline] + fn ensure_length(&mut self, count: usize) -> bool { + self.ensure_capacity(count) } - pub(crate) fn read(&mut self, count: usize) -> Result { + #[inline] + fn read(&mut self, count: usize) -> Result { let remaining = self.read_buffer().remaining(); if remaining.len() < count { return Err(AnyPostgresError::ShortRead); @@ -1778,17 +1590,18 @@ impl Reader { // PORT NOTE: reshaped for borrowck — capture as `RawSlice` before calling // skip(); the read_buffer backing storage is not reallocated by skip(). let slice = bun_ptr::RawSlice::new(&remaining[..count]); - self.skip(count); + protocol::ReaderContext::skip(self, count); Ok(Data::Temporary(slice)) } - pub(crate) fn read_z(&mut self) -> Result { + #[inline] + fn read_z(&mut self) -> Result { let remain = self.read_buffer().remaining(); if let Some(zero) = strings::index_of_char(remain, 0) { // `RawSlice` backref into read_buffer (not reallocated by skip()). let slice = bun_ptr::RawSlice::new(&remain[..zero as usize]); - self.skip(zero as usize + 1); + protocol::ReaderContext::skip(self, zero as usize + 1); return Ok(Data::Temporary(slice)); } @@ -1796,33 +1609,6 @@ impl Reader { } } -impl protocol::ReaderContext for Reader { - #[inline] - fn mark_message_start(&mut self) { - Reader::mark_message_start(self) - } - #[inline] - fn peek(&self) -> &[u8] { - Reader::peek(self) - } - #[inline] - fn skip(&mut self, count: usize) { - Reader::skip(self, count) - } - #[inline] - fn ensure_length(&mut self, count: usize) -> bool { - Reader::ensure_length(*self, count) - } - #[inline] - fn read(&mut self, count: usize) -> Result { - Reader::read(self, count) - } - #[inline] - fn read_z(&mut self) -> Result { - Reader::read_z(self) - } -} - impl PostgresSQLConnection { pub fn buffered_reader(&self) -> protocol::NewReader { protocol::NewReader { @@ -1868,6 +1654,57 @@ impl PostgresSQLConnection { } } + /// Report a failed protocol write on `req`: forward a pending JS exception + /// if one exists, otherwise (optionally) mark the statement as failed and + /// reject the query with the write error. + /// + /// `statement` is a raw pointer rather than `&mut` so that no protected + /// `&mut` borrow spans the JS-re-entrant `on_write_fail` call below — + /// re-entrant JS may reach the same cached statement via `statement_mut()`. + fn report_write_failure( + &self, + req: &PostgresSQLQuery, + err: AnyPostgresError, + statement: Option<*mut PostgresSQLStatement>, + ) { + if let Some(err_) = self.global().try_take_exception() { + req.on_js_error(err_, self.global()); + } else { + if let Some(statement) = statement { + // SAFETY: callers derive the pointer from `req.statement_mut()`, + // whose pointee is kept alive by the intrusive ref `req` holds. + // The `&mut` formed here ends before `on_write_fail` re-enters JS. + let statement = unsafe { &mut *statement }; + statement.status = StatementStatus::Failed; + statement.error_response = Some(StatementError::PostgresError(err)); + } + req.on_write_fail(err, self.global(), self.get_queries_array()); + } + } + + /// Remove a finished/failed request from the queue head, or — when it is + /// not at the head (`offset > 0`) — mark it failed so `advance`'s deferred + /// cleanup pass discards it later. `bump` advances `offset` past the + /// deferred entry. + #[inline] + fn discard_or_defer( + &self, + req_ptr: *mut PostgresSQLQuery, + req: &PostgresSQLQuery, + offset: &mut usize, + bump: bool, + ) { + if *offset == 0 { + self.discard_request(req_ptr); + } else { + // deinit later + req.status.set(QueryStatus::Fail); + if bump { + *offset += 1; + } + } + } + fn advance(&self) { let mut offset: usize = 0; debug!("advance"); @@ -1940,17 +1777,8 @@ impl PostgresSQLConnection { if let Err(err) = PostgresRequest::execute_query(query_str.slice(), self.writer()) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - req.on_write_fail(err, self.global(), self.get_queries_array()); - } - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - } + self.report_write_failure(&req, err, None); + self.discard_or_defer(req_ptr, &req, &mut offset, false); debug!("executeQuery failed: {}", err); continue; } @@ -1976,13 +1804,7 @@ impl PostgresSQLConnection { }; req.on_js_error(ev, self.global()); } - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - } + self.discard_or_defer(req_ptr, &req, &mut offset, true); continue; } StatementStatus::Prepared => { @@ -1991,13 +1813,7 @@ impl PostgresSQLConnection { false, "query value was freed earlier than expected" ); - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - } + self.discard_or_defer(req_ptr, &req, &mut offset, true); continue; }; let binding_value = @@ -2031,22 +1847,8 @@ impl PostgresSQLConnection { self.writer(), ) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - } + self.report_write_failure(&req, err, None); + self.discard_or_defer(req_ptr, &req, &mut offset, true); debug!( "parse, bind and execute failed: {}", <&'static str>::from(err) @@ -2063,22 +1865,8 @@ impl PostgresSQLConnection { columns_value, self.writer(), ) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - } + self.report_write_failure(&req, err, None); + self.discard_or_defer(req_ptr, &req, &mut offset, true); debug!("bind and execute failed: {}", err); continue; } @@ -2126,13 +1914,7 @@ impl PostgresSQLConnection { false, "query value was freed earlier than expected" ); - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - } + self.discard_or_defer(req_ptr, &req, &mut offset, true); continue; }; // prepareAndQueryWithSignature will write + bind + execute, it will change to running after binding is complete @@ -2150,24 +1932,17 @@ impl PostgresSQLConnection { &mut statement.signature, ) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - statement.status = StatementStatus::Failed; - statement.error_response = - Some(StatementError::PostgresError(err)); - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } - if offset == 0 { - self.discard_request(req_ptr); - } else { - // deinit later - req.status.set(QueryStatus::Fail); - } + self.report_write_failure( + &req, + err, + Some(&raw mut *statement), + ); + self.discard_or_defer( + req_ptr, + &req, + &mut offset, + false, + ); debug!( "prepareAndQueryWithSignature failed: {}", <&'static str>::from(err) @@ -2224,18 +1999,11 @@ impl PostgresSQLConnection { self.writer(), ) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - statement.status = StatementStatus::Failed; - statement.error_response = - Some(StatementError::PostgresError(err)); - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } + self.report_write_failure( + &req, + err, + Some(&raw mut *statement), + ); debug_assert!(offset == 0); self.discard_request(req_ptr); debug!( @@ -2270,36 +2038,22 @@ impl PostgresSQLConnection { &statement.signature.fields, connection_writer, ) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - statement.error_response = - Some(StatementError::PostgresError(err)); - statement.status = StatementStatus::Failed; - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } + self.report_write_failure( + &req, + err, + Some(&raw mut *statement), + ); debug_assert!(offset == 0); self.discard_request(req_ptr); debug!("write query failed: {}", <&'static str>::from(err)); continue; } if let Err(err) = connection_writer.write(&protocol::SYNC) { - if let Some(err_) = self.global().try_take_exception() { - req.on_js_error(err_, self.global()); - } else { - statement.error_response = - Some(StatementError::PostgresError(err)); - statement.status = StatementStatus::Failed; - req.on_write_fail( - err, - self.global(), - self.get_queries_array(), - ); - } + self.report_write_failure( + &req, + err, + Some(&raw mut *statement), + ); debug_assert!(offset == 0); self.discard_request(req_ptr); debug!( @@ -2349,13 +2103,7 @@ impl PostgresSQLConnection { continue; } QueryStatus::Success => { - if offset > 0 { - // deinit later - req.status.set(QueryStatus::Fail); - offset += 1; - continue; - } - self.discard_request(req_ptr); + self.discard_or_defer(req_ptr, &req, &mut offset, true); continue; } QueryStatus::Fail => { diff --git a/src/sql_jsc/postgres/PostgresSQLContext.rs b/src/sql_jsc/postgres/PostgresSQLContext.rs index 7ca11a44092..69977971bd5 100644 --- a/src/sql_jsc/postgres/PostgresSQLContext.rs +++ b/src/sql_jsc/postgres/PostgresSQLContext.rs @@ -2,7 +2,9 @@ //! `us_socket_context_t` that used to live here is gone — connections link //! into `RareData.postgres_group`/`postgres_tls_group` instead. -use crate::jsc::{CallFrame, JSGlobalObject, JSValue, StrongOptional, VirtualMachineSqlExt as _}; +use crate::jsc::{ + CallFrame, JSGlobalObject, JSValue, JsResult, StrongOptional, VirtualMachineSqlExt as _, +}; #[repr(C)] #[derive(Default)] @@ -18,7 +20,7 @@ impl PostgresSQLContext { // The #[bun_jsc::host_fn] attribute emits the callconv(jsc.conv) shim; the // `export = "..."` arg gives it the #[unsafe(no_mangle)] symbol name. // TODO(port): bun_jsc::host_fn proc-macro (#[bun_jsc::host_fn(export = "PostgresSQLContext__init")]) - pub fn init(global: &JSGlobalObject, frame: &CallFrame) -> JSValue { + pub fn init(global: &JSGlobalObject, frame: &CallFrame) -> JsResult { // `bun_vm()` → `&'static VirtualMachine` (per-thread singleton); // `as_mut()` is the canonical safe escape hatch for the shrinking set // of `&mut self` helpers like `sql_state()` — one audited unsafe lives @@ -26,7 +28,7 @@ impl PostgresSQLContext { let ctx = &mut global.bun_vm().as_mut().sql_state().postgresql_context; ctx.on_query_resolve_fn.set(global, frame.argument(0)); ctx.on_query_reject_fn.set(global, frame.argument(1)); - JSValue::UNDEFINED + Ok(JSValue::UNDEFINED) } } diff --git a/src/sql_jsc/postgres/PostgresSQLQuery.rs b/src/sql_jsc/postgres/PostgresSQLQuery.rs index eadefe3c368..fac36a5f100 100644 --- a/src/sql_jsc/postgres/PostgresSQLQuery.rs +++ b/src/sql_jsc/postgres/PostgresSQLQuery.rs @@ -17,6 +17,7 @@ use super::error_jsc::postgres_error_to_js; use super::postgres_request as PostgresRequest; use super::postgres_sql_connection; use super::postgres_sql_statement::Status as StatementStatus; +use crate::shared::query_args; use bun_sql::postgres::CommandTag; use bun_sql::postgres::PostgresProtocol as protocol; use bun_sql::postgres::any_postgres_error::AnyPostgresError; @@ -382,44 +383,7 @@ impl PostgresSQLQuery { // comptime { @export(&jsc.toJSHostFn(call), .{ .name = "PostgresSQLQuery__createInstance" }); } // TODO(port): proc-macro emits the PostgresSQLQuery__createInstance export; verify codegen name. pub fn call(global_this: &JSGlobalObject, callframe: &CallFrame) -> JsResult { - let arguments = callframe.arguments(); - let mut args = - crate::jsc::call_frame::ArgumentsSlice::init(global_this.bun_vm(), arguments); - // ArgumentsSlice has Drop. - let Some(query) = args.next_eat() else { - return Err(global_this.throw(format_args!("query must be a string"))); - }; - let Some(values) = args.next_eat() else { - return Err(global_this.throw(format_args!("values must be an array"))); - }; - - if !query.is_string() { - return Err(global_this.throw(format_args!("query must be a string"))); - } - - if values.js_type() != crate::jsc::JSType::Array { - return Err(global_this.throw(format_args!("values must be an array"))); - } - - let pending_value: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED); - let columns: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED); - let js_bigint: JSValue = args.next_eat().unwrap_or(JSValue::FALSE); - let js_simple: JSValue = args.next_eat().unwrap_or(JSValue::FALSE); - - let bigint = js_bigint.is_boolean() && js_bigint.as_boolean(); - let simple = js_simple.is_boolean() && js_simple.as_boolean(); - if simple { - if values.get_length(global_this)? > 0 { - return Err(global_this - .throw_invalid_arguments(format_args!("simple query cannot have parameters"))); - } - if query.get_length(global_this)? >= i32::MAX as u64 { - return Err(global_this.throw_invalid_arguments(format_args!("query is too long"))); - } - } - if !pending_value.js_type().is_array_like() { - return Err(global_this.throw_invalid_argument_type("query", "pendingValue", "Array")); - } + let args = query_args::parse(global_this, callframe)?; let ptr = bun_core::heap::into_raw(Box::new(PostgresSQLQuery::default())); @@ -434,30 +398,24 @@ impl PostgresSQLQuery { // `default()`-initialised by `Box::new` above, so just overwrite the // three non-default fields in place. unsafe { - (*ptr).query = query.to_bun_string(global_this)?; + (*ptr).query = args.query.to_bun_string(global_this)?; (*ptr).this_value.set(JsRef::init_weak(this_value)); (*ptr).flags.set(Flags { - bigint, - simple, + bigint: args.bigint, + simple: args.simple, ..Default::default() }); } - js::binding_set_cached(this_value, global_this, values); - js::pending_value_set_cached(this_value, global_this, pending_value); - if !columns.is_undefined() { - js::columns_set_cached(this_value, global_this, columns); + js::binding_set_cached(this_value, global_this, args.values); + js::pending_value_set_cached(this_value, global_this, args.pending_value); + if !args.columns.is_undefined() { + js::columns_set_cached(this_value, global_this, args.columns); } Ok(this_value) } - pub fn push(&self, global_this: &JSGlobalObject, value: JSValue) { - // TODO(port): Zig source references `this.pending_value` which is not a field on this - // struct — likely dead/broken code in the original. Preserved as a no-op. - let _ = (global_this, value); - } - pub fn do_done( this: &Self, global_object: &JSGlobalObject, @@ -560,6 +518,28 @@ impl PostgresSQLQuery { let writer = connection.writer(); // We need a strong reference to the query so that it doesn't get GC'd this.ref_(); + + // Shared cleanup for the write-failure paths below: release the statement + // ref this query holds (may free a sole-owner statement) and undo the + // speculative `this.ref_()` just taken. + let release_query_ref = || { + this.release_statement(); + // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. + unsafe { Self::deref(this_ptr) }; + }; + // Shared error tail: throw `err` as a postgres error unless an exception + // is already pending. + let throw_write_error = |msg: &[u8], err: AnyPostgresError| -> JsError { + if !global_object.has_exception() { + return global_object.throw_value(postgres_error_to_js( + global_object, + Some(msg), + err, + )); + } + JsError::Thrown + }; + if this.flags.get().simple { bun_core::scoped_log!(Postgres, "executeQuery"); @@ -579,26 +559,10 @@ impl PostgresSQLQuery { let can_execute = !connection.has_query_running(); if can_execute { if let Err(err) = PostgresRequest::execute_query(query_str.slice(), writer) { - // fail to run do cleanup — sole owner just created above - // (rc=1); `release_statement` decrements → 0 frees. - this.release_statement(); - // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. - unsafe { Self::deref(this_ptr) }; - - if !global_object.has_exception() { - return Err(global_object.throw_value(postgres_error_to_js( - global_object, - Some(b"failed to execute query"), - err, - ))); - } - return Err(JsError::Thrown); - } - { - let mut f = connection.flags.get(); - f.set(ConnectionFlags::IS_READY_FOR_QUERY, false); - connection.flags.set(f); + release_query_ref(); + return Err(throw_write_error(b"failed to execute query", err)); } + connection.update_flags(|f| f.set(ConnectionFlags::IS_READY_FOR_QUERY, false)); connection .nonpipelinable_requests .set(connection.nonpipelinable_requests.get() + 1); @@ -611,12 +575,7 @@ impl PostgresSQLQuery { .with_mut(|q| q.write_item(this_ptr)) .is_err() { - // fail to run do cleanup — sole owner just created above - // (rc=1); `release_statement` decrements → 0 frees. - this.release_statement(); - // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. - unsafe { Self::deref(this_ptr) }; - + release_query_ref(); return Err(global_object.throw_out_of_memory()); } @@ -689,6 +648,7 @@ impl PostgresSQLQuery { Ok(v) => v, Err(err) => { drop(signature); + release_query_ref(); return Err( global_object.throw_error(err.into(), "failed to allocate statement") ); @@ -729,27 +689,15 @@ impl PostgresSQLQuery { columns_value, writer, ) { - // fail to run do cleanup — drop the ref we took above. - this.release_statement(); - // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. - unsafe { Self::deref(this_ptr) }; - - if !global_object.has_exception() { - return Err(global_object.throw_value( - postgres_error_to_js( - global_object, - Some(b"failed to bind and execute query"), - err, - ), - )); - } - return Err(JsError::Thrown); - } - { - let mut f = connection.flags.get(); - f.set(ConnectionFlags::IS_READY_FOR_QUERY, false); - connection.flags.set(f); + release_query_ref(); + return Err(throw_write_error( + b"failed to bind and execute query", + err, + )); } + connection.update_flags(|f| { + f.set(ConnectionFlags::IS_READY_FOR_QUERY, false) + }); this.status.set(Status::Binding); this.update_flags(|f| f.pipelined = true); connection @@ -765,6 +713,16 @@ impl PostgresSQLQuery { break 'enqueue; } } + // Shared cleanup for the write-failure paths below: remove the + // speculative statements-map entry (present only when using named + // prepared statements). + let remove_statements_entry = || { + if connection_entry_value.is_some() { + let _ = connection + .statements + .with_mut(|m| m.remove(&signature_hash)); + } + }; let can_execute = !connection.has_query_running(); if can_execute { @@ -779,30 +737,15 @@ impl PostgresSQLQuery { writer, &mut signature, ) { - if connection_entry_value.is_some() { - let _ = connection - .statements - .with_mut(|m| m.remove(&signature_hash)); - } + remove_statements_entry(); drop(signature); - this.release_statement(); - // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. - unsafe { Self::deref(this_ptr) }; - if !global_object.has_exception() { - return Err(global_object.throw_value(postgres_error_to_js( - global_object, - Some(b"failed to prepare and query"), - err, - ))); - } - return Err(JsError::Thrown); + release_query_ref(); + return Err(throw_write_error(b"failed to prepare and query", err)); } - { - let mut f = connection.flags.get(); + connection.update_flags(|f| { f.set(ConnectionFlags::IS_READY_FOR_QUERY, false); f.set(ConnectionFlags::WAITING_TO_PREPARE, true); - connection.flags.set(f); - } + }); this.status.set(Status::Binding); did_write = true; } else if !connection @@ -820,46 +763,25 @@ impl PostgresSQLQuery { &signature.fields, writer, ) { - if connection_entry_value.is_some() { - let _ = connection - .statements - .with_mut(|m| m.remove(&signature_hash)); - } + remove_statements_entry(); drop(signature); - this.release_statement(); - // SAFETY: undoes the speculative `this.ref_()` above; count was ≥2, never frees here. - unsafe { Self::deref(this_ptr) }; - if !global_object.has_exception() { - return Err(global_object.throw_value(postgres_error_to_js( - global_object, - Some(b"failed to write query"), - err, - ))); - } - return Err(JsError::Thrown); + release_query_ref(); + return Err(throw_write_error(b"failed to write query", err)); } if let Err(err) = writer.write(&protocol::SYNC) { - if connection_entry_value.is_some() { - let _ = connection - .statements - .with_mut(|m| m.remove(&signature_hash)); - } + remove_statements_entry(); drop(signature); - if !global_object.has_exception() { - return Err(global_object.throw_value(postgres_error_to_js( - global_object, - Some(b"failed to flush"), - err, - ))); - } - return Err(JsError::Thrown); + // PORT NOTE: the Zig original omitted this on the Sync + // failure path (alone among the write-error branches), + // leaking the speculative ref. Release it like the + // `write_query` branch above does. + release_query_ref(); + return Err(throw_write_error(b"failed to flush", err)); } - { - let mut f = connection.flags.get(); + connection.update_flags(|f| { f.set(ConnectionFlags::IS_READY_FOR_QUERY, false); f.set(ConnectionFlags::WAITING_TO_PREPARE, true); - connection.flags.set(f); - } + }); did_write = true; } // Unnamed prepared statements with params: skip writeQuery+Sync here. @@ -913,6 +835,7 @@ impl PostgresSQLQuery { .with_mut(|q| q.write_item(this_ptr)) .is_err() { + release_query_ref(); return Err(global_object.throw_out_of_memory()); } this.this_value.with_mut(|r| r.upgrade(global_object)); diff --git a/src/sql_jsc/postgres/PostgresSQLStatement.rs b/src/sql_jsc/postgres/PostgresSQLStatement.rs index 162733befb8..d534b025f7d 100644 --- a/src/sql_jsc/postgres/PostgresSQLStatement.rs +++ b/src/sql_jsc/postgres/PostgresSQLStatement.rs @@ -1,7 +1,6 @@ use core::cell::Cell; use crate::jsc::{JSGlobalObject, JSValue, JsResult}; -use bun_collections::StringHashMap; use crate::postgres::error_jsc::postgres_error_to_js; use crate::postgres::signature::Signature; @@ -11,7 +10,6 @@ use crate::shared::sql_data_cell::Flags as DataCellFlags; use bun_sql::postgres::any_postgres_error::AnyPostgresError; use bun_sql::postgres::postgres_protocol as protocol; use bun_sql::postgres::postgres_types::int4; -use bun_sql::shared::ColumnIdentifier; bun_core::declare_scope!(Postgres, visible); @@ -103,52 +101,9 @@ impl PostgresSQLStatement { } self.needs_duplicate_check = false; - let mut seen_numbers: Vec = Vec::new(); - let mut seen_fields: StringHashMap<()> = StringHashMap::default(); - seen_fields.reserve(self.fields.len()); - - // iterate backwards - let mut remaining = self.fields.len(); - let mut flags = DataCellFlags::default(); - while remaining > 0 { - remaining -= 1; - let field: &mut protocol::FieldDescription = &mut self.fields[remaining]; - match &field.name_or_index { - ColumnIdentifier::Name(name) => { - // PORT NOTE: reshaped for borrowck — compute `found_existing` - // before mutating `field.name_or_index`. - // TODO(port): Zig `getOrPut` keys on the borrowed slice; - // StringHashMap clones to an owned `Box<[u8]>` key. Fine for - // a transient dedup set; revisit if profiling flags it. - let found_existing = seen_fields - .get_or_put(name.slice()) - .expect("OOM") - .found_existing; - if found_existing { - field.name_or_index = ColumnIdentifier::Duplicate; - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } - - flags.insert(DataCellFlags::HAS_NAMED_COLUMNS); - } - ColumnIdentifier::Index(index) => { - let index = *index; - if seen_numbers.contains(&index) { - field.name_or_index = ColumnIdentifier::Duplicate; - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } else { - seen_numbers.push(index); - } - - flags.insert(DataCellFlags::HAS_INDEXED_COLUMNS); - } - ColumnIdentifier::Duplicate => { - flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); - } - } - } - - self.fields_flags = flags; + self.fields_flags = crate::shared::cached_structure::mark_duplicate_columns( + self.fields.iter_mut().map(|f| &mut f.name_or_index), + ); } // PORT NOTE: Zig returns `CachedStructure` by value (struct copy). Returning diff --git a/src/sql_jsc/postgres/Signature.rs b/src/sql_jsc/postgres/Signature.rs index c49f898b6d3..21c169ac0b2 100644 --- a/src/sql_jsc/postgres/Signature.rs +++ b/src/sql_jsc/postgres/Signature.rs @@ -23,24 +23,6 @@ impl Signature { // `bun.default_allocator.free`. With `Box<[T]>` fields, Rust's `Drop` // handles this automatically — no explicit `Drop` impl needed. - pub fn hash(&self) -> u64 { - // PORT NOTE: Zig `std.hash.Wyhash.init(0)` + `update` + `final`. The - // `bun_wyhash` crate exposes the streaming API as `Wyhash11` (and a - // stateless `hash`); for now use the one-shot `bun_wyhash::hash` over - // a concatenated byte view. - // `Int4` (= u32) is `NoUninit`; safe `&[u32]` → `&[u8]` view (matches - // Zig `std.mem.sliceAsBytes`). - let fields_bytes: &[u8] = bun_core::cast_slice(&self.fields[..]); - // PERF(port): Zig fed two slices into a streaming Wyhash; bun_wyhash - // currently lacks the std-compatible streaming `Wyhash` type. Concatenate - // into a temp Vec until `bun_wyhash::Wyhash` (streaming, seed-0) lands. - // TODO(port): bun_wyhash::Wyhash (streaming std-compatible API) - let mut buf: Vec = Vec::with_capacity(self.name.len() + fields_bytes.len()); - buf.extend_from_slice(&self.name); - buf.extend_from_slice(fields_bytes); - bun_wyhash::hash(&buf) - } - // TODO(port): narrow error set — Zig inferred set mixes JSError (from // QueryBindingIterator / Tag::from_js), OOM, and error.InvalidQueryBinding. pub fn generate( diff --git a/src/sql_jsc/postgres/protocol/error_response_jsc.rs b/src/sql_jsc/postgres/protocol/error_response_jsc.rs index 60e4d07f862..7a4359ea894 100644 --- a/src/sql_jsc/postgres/protocol/error_response_jsc.rs +++ b/src/sql_jsc/postgres/protocol/error_response_jsc.rs @@ -19,44 +19,40 @@ pub(crate) fn to_js(this: &ErrorResponse, global_object: &JSGlobalObject) -> JSV } let _ = b.allocate(); - let mut severity: &String = &String::DEAD; + fn maybe_slice(s: &String) -> Option<&[u8]> { + if s.is_empty() { + None + } else { + Some(s.byte_slice()) + } + } + let mut code: &String = &String::DEAD; let mut message: &String = &String::DEAD; let mut detail: &String = &String::DEAD; let mut hint: &String = &String::DEAD; - let mut position: &String = &String::DEAD; - let mut internal_position: &String = &String::DEAD; - let mut internal: &String = &String::DEAD; - let mut where_: &String = &String::DEAD; - let mut schema: &String = &String::DEAD; - let mut table: &String = &String::DEAD; - let mut column: &String = &String::DEAD; - let mut datatype: &String = &String::DEAD; - let mut constraint: &String = &String::DEAD; - let mut file: &String = &String::DEAD; - let mut line: &String = &String::DEAD; - let mut routine: &String = &String::DEAD; + let mut opts = PostgresErrorOptions::default(); for msg in this.messages.iter() { match msg { - FieldMessage::Severity(str) => severity = str, + FieldMessage::Severity(str) => opts.severity = maybe_slice(str), FieldMessage::Code(str) => code = str, FieldMessage::Message(str) => message = str, FieldMessage::Detail(str) => detail = str, FieldMessage::Hint(str) => hint = str, - FieldMessage::Position(str) => position = str, - FieldMessage::InternalPosition(str) => internal_position = str, - FieldMessage::Internal(str) => internal = str, - FieldMessage::Where(str) => where_ = str, - FieldMessage::Schema(str) => schema = str, - FieldMessage::Table(str) => table = str, - FieldMessage::Column(str) => column = str, - FieldMessage::Datatype(str) => datatype = str, - FieldMessage::Constraint(str) => constraint = str, - FieldMessage::File(str) => file = str, - FieldMessage::Line(str) => line = str, - FieldMessage::Routine(str) => routine = str, - _ => {} + FieldMessage::Position(str) => opts.position = maybe_slice(str), + FieldMessage::InternalPosition(str) => opts.internal_position = maybe_slice(str), + FieldMessage::Internal(str) => opts.internal_query = maybe_slice(str), + FieldMessage::Where(str) => opts.r#where = maybe_slice(str), + FieldMessage::Schema(str) => opts.schema = maybe_slice(str), + FieldMessage::Table(str) => opts.table = maybe_slice(str), + FieldMessage::Column(str) => opts.column = maybe_slice(str), + FieldMessage::Datatype(str) => opts.data_type = maybe_slice(str), + FieldMessage::Constraint(str) => opts.constraint = maybe_slice(str), + FieldMessage::File(str) => opts.file = maybe_slice(str), + FieldMessage::Line(str) => opts.line = maybe_slice(str), + FieldMessage::Routine(str) => opts.routine = maybe_slice(str), + FieldMessage::LocalizedSeverity(_) => {} } } @@ -91,38 +87,16 @@ pub(crate) fn to_js(this: &ErrorResponse, global_object: &JSGlobalObject) -> JSV } let _ = needs_newline; - fn maybe_slice(s: &String) -> Option<&[u8]> { - if s.is_empty() { - None - } else { - Some(s.byte_slice()) - } - } - - let errno = maybe_slice(code); + opts.errno = maybe_slice(code); + opts.detail = maybe_slice(detail); + opts.hint = maybe_slice(hint); // syntax error - https://www.postgresql.org/docs/8.1/errcodes-appendix.html - let error_code: &'static [u8] = if code.eql_utf8(b"42601") { + opts.code = if code.eql_utf8(b"42601") { b"ERR_POSTGRES_SYNTAX_ERROR" } else { b"ERR_POSTGRES_SERVER_ERROR" }; - let detail_slice = maybe_slice(detail); - let hint_slice = maybe_slice(hint); - let severity_slice = maybe_slice(severity); - let position_slice = maybe_slice(position); - let internal_position_slice = maybe_slice(internal_position); - let internal_query_slice = maybe_slice(internal); - let where_slice = maybe_slice(where_); - let schema_slice = maybe_slice(schema); - let table_slice = maybe_slice(table); - let column_slice = maybe_slice(column); - let data_type_slice = maybe_slice(datatype); - let constraint_slice = maybe_slice(constraint); - let file_slice = maybe_slice(file); - let line_slice = maybe_slice(line); - let routine_slice = maybe_slice(routine); - // PORT NOTE: reshaped for borrowck — `b.allocated_slice()` borrows `b` // mutably; capture `b.len` first. let len = b.len; @@ -132,30 +106,8 @@ pub(crate) fn to_js(this: &ErrorResponse, global_object: &JSGlobalObject) -> JSV b"" }; - create_postgres_error( - global_object, - error_message, - &PostgresErrorOptions { - code: error_code, - errno, - detail: detail_slice, - hint: hint_slice, - severity: severity_slice, - position: position_slice, - internal_position: internal_position_slice, - internal_query: internal_query_slice, - r#where: where_slice, - schema: schema_slice, - table: table_slice, - column: column_slice, - data_type: data_type_slice, - constraint: constraint_slice, - file: file_slice, - line: line_slice, - routine: routine_slice, - }, - ) - .unwrap_or_else(|e| global_object.take_error(e)) + create_postgres_error(global_object, error_message, &opts) + .unwrap_or_else(|e| global_object.take_error(e)) } // ported from: src/sql_jsc/postgres/protocol/error_response_jsc.zig diff --git a/src/sql_jsc/postgres/types/PostgresString.rs b/src/sql_jsc/postgres/types/PostgresString.rs deleted file mode 100644 index bf178b5f4a6..00000000000 --- a/src/sql_jsc/postgres/types/PostgresString.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::jsc::{JSGlobalObject, JSValue, StringJsc as _, js_error_to_postgres}; -use bun_sql::postgres::AnyPostgresError; -use bun_sql::shared::Data; - -/// "no impl" compile error. -pub trait ToJsWithType { - fn to_js_with_type(self, global: &JSGlobalObject) -> Result; -} - -// Covers Zig arms `[:0]u8, []u8, []const u8, [:0]const u8` — all collapse to a byte slice. -impl ToJsWithType for &[u8] { - fn to_js_with_type(self, global: &JSGlobalObject) -> Result { - let str = bun_core::String::borrow_utf8(self); - // `defer str.deinit()` → Drop on bun_core::String - str.to_js(global).map_err(js_error_to_postgres) - } -} - -impl ToJsWithType for bun_core::String { - fn to_js_with_type(self, global: &JSGlobalObject) -> Result { - self.to_js(global).map_err(js_error_to_postgres) - } -} - -impl ToJsWithType for &mut Data { - fn to_js_with_type(self, global: &JSGlobalObject) -> Result { - let str = bun_core::String::borrow_utf8(self.slice()); - // `defer str.deinit()` → Drop on bun_core::String - // TODO(port): Zig calls `value.deinit()` here (consumes the Data). In Rust, Data's - // Drop should handle this at the caller's scope; revisit ownership if Data must be - // freed before this fn returns. - str.to_js(global).map_err(js_error_to_postgres) - } -} - -// ported from: src/sql_jsc/postgres/types/PostgresString.zig diff --git a/src/sql_jsc/postgres/types/bool.rs b/src/sql_jsc/postgres/types/bool.rs deleted file mode 100644 index 29bfa6840da..00000000000 --- a/src/sql_jsc/postgres/types/bool.rs +++ /dev/null @@ -1 +0,0 @@ -// ported from: src/sql_jsc/postgres/types/bool.zig diff --git a/src/sql_jsc/postgres/types/bytea.rs b/src/sql_jsc/postgres/types/bytea.rs deleted file mode 100644 index e6cf20fd7b7..00000000000 --- a/src/sql_jsc/postgres/types/bytea.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::jsc::{ArrayBuffer, JSGlobalObject, JSValue, js_error_to_postgres}; -use bun_sql::postgres::AnyPostgresError; -use bun_sql::shared::Data; - -// PostgresString.rs. -pub trait ByteaToJs { - fn bytea_to_js(self, global: &JSGlobalObject) -> Result; -} - -// PORT NOTE: reshaped `value: *Data` + `defer value.deinit()` → owned `Data`; -// Drop at scope exit replaces the explicit deinit. -impl ByteaToJs for Data { - fn bytea_to_js(self, global: &JSGlobalObject) -> Result { - // var slice = value.slice()[@min(1, value.len)..]; - // _ = slice; - // - // Zig's `JSValue.createBuffer(global, slice, null)` with a null - // allocator maps to the copying Buffer constructor: `self.slice()` - // borrows a transient decode buffer that `Drop` frees on return, so - // JSC must own its own copy. - ArrayBuffer::create_buffer(global, self.slice()).map_err(js_error_to_postgres) - } -} - -// ported from: src/sql_jsc/postgres/types/bytea.zig diff --git a/src/sql_jsc/postgres/types/date.rs b/src/sql_jsc/postgres/types/date.rs index 3f43fb6d674..6bb45277cff 100644 --- a/src/sql_jsc/postgres/types/date.rs +++ b/src/sql_jsc/postgres/types/date.rs @@ -1,6 +1,5 @@ use crate::jsc::{JSGlobalObject, JSValue, JsResult}; use bun_sql::postgres::types::int_types::Short; -use bun_sql::shared::Data; pub const TO: i32 = 1184; pub const FROM: [Short; 3] = [1082, 1114, 1184]; @@ -62,42 +61,4 @@ pub fn from_js(global_object: &JSGlobalObject, value: JSValue) -> JsResult Ok((unix_timestamp - POSTGRES_EPOCH_DATE) * US_PER_MS) } -// Zig `toJS(value: anytype)` dispatches on `@TypeOf(value)` at comptime over a -// closed set {i64, *Data}. Rust has no comptime type-switch; modeled as a trait -// with per-type impls so `tag_jsc::to_js_with_type` can dispatch uniformly. The -// `else => @compileError(...)` arm is the natural "no impl" compile error. -pub trait DateToJs { - fn date_to_js(self, global_object: &JSGlobalObject) -> JSValue; -} - -impl DateToJs for i64 { - fn date_to_js(self, global_object: &JSGlobalObject) -> JSValue { - to_js_i64(global_object, self) - } -} - -impl DateToJs for Data { - fn date_to_js(self, global_object: &JSGlobalObject) -> JSValue { - to_js_data(global_object, &self) - } -} - -pub fn to_js(global_object: &JSGlobalObject, value: T) -> JSValue { - value.date_to_js(global_object) -} - -pub fn to_js_i64(global_object: &JSGlobalObject, value: i64) -> JSValue { - // Convert from Postgres timestamp (μs since 2000-01-01) to Unix timestamp (ms) - let ms = value.div_euclid(US_PER_MS) + POSTGRES_EPOCH_DATE; - JSValue::from_date_number(global_object, ms as f64) -} - -pub fn to_js_data(global_object: &JSGlobalObject, value: &Data) -> JSValue { - let z = value.slice_z(); - // SAFETY: ZStr invariant guarantees a readable NUL terminator at `len`; Postgres - // date payloads contain no interior NULs, satisfying CStr's contract. - let cstr = unsafe { bun_core::ffi::cstr(z.as_ptr()) }; - JSValue::from_date_string(global_object, cstr) -} - // ported from: src/sql_jsc/postgres/types/date.zig diff --git a/src/sql_jsc/postgres/types/json.rs b/src/sql_jsc/postgres/types/json.rs deleted file mode 100644 index 67954bb100e..00000000000 --- a/src/sql_jsc/postgres/types/json.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::jsc::{JSGlobalObject, JSValue, StringJsc as _, js_error_to_postgres}; -use bun_sql::postgres::AnyPostgresError; -use bun_sql::shared::Data; - -// bytea.rs. -pub trait JsonToJs { - fn json_to_js(self, global: &JSGlobalObject) -> Result; -} - -// PORT NOTE: reshaped `value: *Data` + `defer value.deinit()` → owned `Data`; -// Drop at scope exit replaces the explicit deinit. -impl JsonToJs for Data { - fn json_to_js(self, global: &JSGlobalObject) -> Result { - let str = bun_core::String::borrow_utf8(self.slice()); - // `defer str.deref()` — handled by Drop on bun_core::String. - let js_str = str.to_js(global).map_err(js_error_to_postgres)?; - let parse_result = js_str.parse_json(global).map_err(js_error_to_postgres)?; - // PORT NOTE: Zig `parse_result.AnyPostgresError()` is a typo for - // `.isAnyError()` (verified against bun_jsc surface — no `AnyPostgresError` - // method exists on JSValue). - if parse_result.is_any_error() { - return Err(js_error_to_postgres(global.throw_value(parse_result))); - } - - Ok(parse_result) - } -} - -// ported from: src/sql_jsc/postgres/types/json.zig diff --git a/src/sql_jsc/postgres/types/tag_jsc.rs b/src/sql_jsc/postgres/types/tag_jsc.rs index 7ed044a9026..e4606337362 100644 --- a/src/sql_jsc/postgres/types/tag_jsc.rs +++ b/src/sql_jsc/postgres/types/tag_jsc.rs @@ -3,9 +3,7 @@ //! conversion paths live here. use crate::jsc::{JSGlobalObject, JSType, JSValue, JsResult}; -use bun_sql::postgres::AnyPostgresError; use bun_sql::postgres::types::tag::Tag; -use bun_sql::shared::Data; // `comptime T: Tag` → const generic per PORTING.md. `Tag` in the Rust port is a // `#[repr(transparent)] struct Tag(Short)` with associated consts (non-exhaustive @@ -23,24 +21,6 @@ pub(crate) fn to_js_typed_array_type(t: Tag) -> Result } } -/// rest may `unreachable!()` (mirroring Zig's per-monomorphization compile -/// error becoming a runtime impossibility once the `tag` is fixed). -pub trait TagToJs: Sized { - /// `.numeric | .float4 | .float8 | .int4` arms → `JSValue.jsNumber(value)`. - fn as_js_number(self) -> f64; - /// `.int8` arm → `JSValue.fromInt64NoTruncate(global, value)`. - fn as_i64(self) -> i64; - /// `.bool` arm → `bool.toJS(global, value)`. - fn as_bool(self) -> bool; - /// `.json | .jsonb | .bytea` arms → `json.toJS` / `bytea.toJS`, both of - /// which take owned `Data` in the Rust port. - fn into_data(self) -> Data; - /// `.timestamp | .timestamptz` arm → `date.toJS(global, value)`. - fn date_to_js(self, global: &JSGlobalObject) -> JSValue; - /// `else` arm → `string.toJS(global, value)`. - fn string_to_js(self, global: &JSGlobalObject) -> Result; -} - pub fn from_js(global: &JSGlobalObject, value: JSValue) -> JsResult { if value.is_empty_or_undefined_or_null() { return Ok(Tag::numeric); diff --git a/src/sql_jsc/shared/CachedStructure.rs b/src/sql_jsc/shared/CachedStructure.rs index dd1da43410c..11e7b7291f4 100644 --- a/src/sql_jsc/shared/CachedStructure.rs +++ b/src/sql_jsc/shared/CachedStructure.rs @@ -1,8 +1,66 @@ use core::mem::{ManuallyDrop, MaybeUninit}; use crate::jsc::{ExternColumnIdentifier, JSGlobalObject, JSObject, JSValue, StrongOptional}; +use crate::shared::sql_data_cell::Flags as DataCellFlags; +use bun_collections::StringHashMap; +use bun_core::UnwrapOrOom; use bun_sql::shared::ColumnIdentifier; +/// Shared body of `{Postgres,MySQL}SQLStatement::check_for_duplicate_fields()`. +/// +/// Scans the columns backwards, retagging repeated names/indices as +/// [`ColumnIdentifier::Duplicate`], and returns the accumulated data-cell +/// flags. Run this before [`CachedStructure::build_from_columns`] so the +/// `Duplicate` tags it skips are present. +pub fn mark_duplicate_columns<'a, I>(columns: I) -> DataCellFlags +where + I: DoubleEndedIterator + ExactSizeIterator, +{ + let mut seen_numbers: Vec = Vec::new(); + // PERF(port): Zig `getOrPut` keyed on the borrowed `name.slice()`; + // StringHashMap clones to an owned `Box<[u8]>` key. Fine for a transient + // dedup set — profile if it shows up on a hot path. + let mut seen_fields: StringHashMap<()> = StringHashMap::default(); + seen_fields.reserve(columns.len()); + + let mut flags = DataCellFlags::default(); + // iterate backwards + for name_or_index in columns.rev() { + match &*name_or_index { + ColumnIdentifier::Name(name) => { + // PORT NOTE: reshaped for borrowck — compute `found_existing` + // before mutating `*name_or_index`. + let found_existing = seen_fields + .get_or_put(name.slice()) + .unwrap_or_oom() + .found_existing; + if found_existing { + *name_or_index = ColumnIdentifier::Duplicate; + flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); + } + + flags.insert(DataCellFlags::HAS_NAMED_COLUMNS); + } + ColumnIdentifier::Index(index) => { + let index = *index; + if seen_numbers.contains(&index) { + *name_or_index = ColumnIdentifier::Duplicate; + flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); + } else { + seen_numbers.push(index); + } + + flags.insert(DataCellFlags::HAS_INDEXED_COLUMNS); + } + ColumnIdentifier::Duplicate => { + flags.insert(DataCellFlags::HAS_DUPLICATE_COLUMNS); + } + } + } + + flags +} + #[derive(Default)] pub struct CachedStructure { pub structure: StrongOptional, // Strong.Optional = .empty diff --git a/src/sql_jsc/shared/QueryBindingIterator.rs b/src/sql_jsc/shared/QueryBindingIterator.rs index bca1525b528..cb65a9e16a0 100644 --- a/src/sql_jsc/shared/QueryBindingIterator.rs +++ b/src/sql_jsc/shared/QueryBindingIterator.rs @@ -54,19 +54,6 @@ impl<'a> QueryBindingIterator<'a> { } } } - - pub fn reset(&mut self) { - match self { - Self::Array(iter) => { - iter.i = 0; - } - Self::Objects(iter) => { - iter.cell_i = 0; - iter.row_i = 0; - iter.current_row = JSValue::ZERO; - } - } - } } // ported from: src/sql_jsc/shared/QueryBindingIterator.zig diff --git a/src/sql_jsc/shared/SQLDataCell.rs b/src/sql_jsc/shared/SQLDataCell.rs index fb74a8f908e..0147f19844a 100644 --- a/src/sql_jsc/shared/SQLDataCell.rs +++ b/src/sql_jsc/shared/SQLDataCell.rs @@ -264,8 +264,117 @@ impl SQLDataCell { } } - pub fn raw<'a>(optional_bytes: impl IntoOptionalData<'a>) -> SQLDataCell { - if let Some(bytes) = optional_bytes.into_optional_data() { + #[inline] + pub fn null() -> SQLDataCell { + SQLDataCell::default() + } + + #[inline] + pub fn int4(value: i32) -> SQLDataCell { + SQLDataCell { + tag: Tag::Int4, + value: Value { int4: value }, + ..Default::default() + } + } + + #[inline] + pub fn uint4(value: u32) -> SQLDataCell { + SQLDataCell { + tag: Tag::Uint4, + value: Value { uint4: value }, + ..Default::default() + } + } + + #[inline] + pub fn int8(value: i64) -> SQLDataCell { + SQLDataCell { + tag: Tag::Int8, + value: Value { int8: value }, + ..Default::default() + } + } + + #[inline] + pub fn uint8(value: u64) -> SQLDataCell { + SQLDataCell { + tag: Tag::Uint8, + value: Value { uint8: value }, + ..Default::default() + } + } + + #[inline] + pub fn float8(value: f64) -> SQLDataCell { + SQLDataCell { + tag: Tag::Float8, + value: Value { float8: value }, + ..Default::default() + } + } + + #[inline] + pub fn bool_(value: bool) -> SQLDataCell { + SQLDataCell { + tag: Tag::Bool, + value: Value { bool_: value as u8 }, + ..Default::default() + } + } + + #[inline] + pub fn date(value: f64) -> SQLDataCell { + SQLDataCell { + tag: Tag::Date, + value: Value { date: value }, + ..Default::default() + } + } + + #[inline] + pub fn date_with_tz(value: f64) -> SQLDataCell { + SQLDataCell { + tag: Tag::DateWithTimeZone, + value: Value { + date_with_time_zone: value, + }, + ..Default::default() + } + } + + /// Owned string cell: clones `bytes` into a WTFStringImpl, freed via + /// `free_value = 1`. Empty input becomes a null pointer, which the C++ + /// side (SQLClient.cpp) renders as the empty string. + #[inline] + pub fn string(bytes: &[u8]) -> SQLDataCell { + SQLDataCell { + tag: Tag::String, + value: Value { + string: clone_utf8_or_null(bytes), + }, + free_value: 1, + ..Default::default() + } + } + + /// Owned JSON cell: clones `bytes` into a WTFStringImpl, freed via + /// `free_value = 1`. Empty input becomes a null pointer, which the C++ + /// side (SQLClient.cpp) renders as `null`. + #[inline] + pub fn json(bytes: &[u8]) -> SQLDataCell { + SQLDataCell { + tag: Tag::Json, + value: Value { + json: clone_utf8_or_null(bytes), + }, + free_value: 1, + ..Default::default() + } + } + + pub fn raw(optional_bytes: Option<&Data>) -> SQLDataCell { + if let Some(bytes) = optional_bytes { let bytes_slice = bytes.slice(); return SQLDataCell { tag: Tag::Raw, @@ -279,11 +388,38 @@ impl SQLDataCell { }; } // TODO: check empty and null fields - SQLDataCell { - tag: Tag::Null, - value: Value { null: 0 }, - ..Default::default() - } + SQLDataCell::null() + } + + /// Shared wrapper around `construct_object_from_data_cell` used by the + /// per-row `to_js` paths (postgres `Putter`, mysql `Row`): extracts the + /// cached-structure column names and forwards the cells. + pub fn to_js_object( + global_object: &JSGlobalObject, + array: JSValue, + structure: JSValue, + cells: &mut [SQLDataCell], + count: u32, + flags: Flags, + result_mode: u8, + cached_structure: Option<&crate::shared::CachedStructure>, + ) -> JsResult { + let (names, names_count) = match cached_structure.and_then(|c| c.fields.as_deref()) { + Some(f) => (f.as_ptr().cast_mut(), f.len() as u32), + None => (ptr::null_mut(), 0), + }; + + SQLDataCell::construct_object_from_data_cell( + global_object, + array, + structure, + cells.as_mut_ptr(), + count, + flags, + result_mode, + names, + names_count, + ) } // TODO: cppbind isn't yet able to detect slice parameters when the next is uint32_t @@ -331,34 +467,16 @@ impl SQLDataCell { } } -/// Coercion helper mirroring Zig's implicit `*const Data` → `?*const Data` -/// promotion at `raw()` call sites. Lets callers pass `&Data`, `&mut Data`, -/// `Option<&Data>`, or `Option<&mut Data>` without wrapping. -pub trait IntoOptionalData<'a> { - fn into_optional_data(self) -> Option<&'a Data>; -} -impl<'a> IntoOptionalData<'a> for &'a Data { - #[inline] - fn into_optional_data(self) -> Option<&'a Data> { - Some(self) - } -} -impl<'a> IntoOptionalData<'a> for &'a mut Data { - #[inline] - fn into_optional_data(self) -> Option<&'a Data> { - Some(&*self) - } -} -impl<'a> IntoOptionalData<'a> for Option<&'a Data> { - #[inline] - fn into_optional_data(self) -> Option<&'a Data> { - self - } -} -impl<'a> IntoOptionalData<'a> for Option<&'a mut Data> { - #[inline] - fn into_optional_data(self) -> Option<&'a Data> { - self.map(|d| &*d) +/// `bun.String.cloneUTF8(slice).value.WTFStringImpl` in Zig — clones the bytes +/// into a fresh WTFStringImpl whose +1 ref is transferred to the cell +/// (`free_value = 1`). Empty input maps to a null pointer instead of +/// allocating an empty string. +#[inline] +fn clone_utf8_or_null(bytes: &[u8]) -> WTFStringImpl { + if !bytes.is_empty() { + bun_core::String::clone_utf8(bytes).leak_wtf_impl() + } else { + ptr::null_mut() } } diff --git a/src/sql_jsc/shared/connection_args.rs b/src/sql_jsc/shared/connection_args.rs new file mode 100644 index 00000000000..f30b3621c48 --- /dev/null +++ b/src/sql_jsc/shared/connection_args.rs @@ -0,0 +1,259 @@ +//! Shared `createInstance` argument parsing for the SQL connection bindings. +//! +//! `PostgresSQLConnection::call` and `JSMySQLConnection::create_instance` receive the +//! same 15 leading JS arguments; [`parse`] owns the one copy of that decoding — including +//! TLS `SSL_CTX` creation, its cleanup guard for early-return paths, and the null-byte +//! injection check — so the two drivers can't drift. Driver-specific arguments (MySQL's +//! `allowPublicKeyRetrieval`, argument 15) stay at the call sites. +//! +//! [`verify_tls_server`] is the matching post-handshake half: the one copy of the +//! server-certificate / hostname verification both drivers run after the TLS handshake. + +use crate::jsc::api::server_config::SSLConfig; +use crate::jsc::{ + CallFrame, JSGlobalObject, JSValue, JsResult, VirtualMachine, VirtualMachineSqlExt as _, +}; +use bun_core::{OwnedString, ZigStringSlice, strings}; +use bun_uws as uws; + +/// `bun_sql::postgres::SSLMode` and `bun_sql::mysql::ssl_mode::SSLMode` are identical +/// `#[repr(u8)]` enums; this trait lets [`parse`] produce whichever one the caller stores. +pub(crate) trait SslModeArg: Copy + PartialEq { + /// Variants in discriminant order (`Disable = 0` .. `VerifyFull = 4`). + const ALL: [Self; 5]; + + /// `arguments[5]` — out-of-range values fall back to `Disable`. + fn from_i32(value: i32) -> Self { + let [disable, prefer, require, verify_ca, verify_full] = Self::ALL; + match value { + 1 => prefer, + 2 => require, + 3 => verify_ca, + 4 => verify_full, + _ => disable, + } + } + + fn is_disable(self) -> bool { + self == Self::ALL[0] + } +} + +macro_rules! impl_ssl_mode_arg { + ($($t:ty),+ $(,)?) => {$( + impl SslModeArg for $t { + const ALL: [Self; 5] = + [Self::Disable, Self::Prefer, Self::Require, Self::VerifyCa, Self::VerifyFull]; + } + )+}; +} +impl_ssl_mode_arg!( + bun_sql::postgres::SSLMode, + bun_sql::mysql::ssl_mode::SSLMode +); + +type TlsPair = (Option<*mut uws::SslCtx>, SSLConfig); + +/// Owns `(secure, tls_config)` until the caller transfers them into its connection struct +/// via `scopeguard::ScopeGuard::into_inner`. Dropping the guard instead (any early-return +/// path) releases the `SSL_CTX` ref and the config — the Rust spelling of the Zig `errdefer`. +pub(crate) type TlsGuard = scopeguard::ScopeGuard; + +fn release_tls((secure, tls_config): TlsPair) { + if let Some(s) = secure { + // SAFETY: `s` came from `ssl_ctx_cache().get_or_create_opts()`; this guard owns + // the one outstanding ref. + unsafe { bun_boringssl_sys::SSL_CTX_free(s) }; + } + drop(tls_config); +} + +/// Post-TLS-handshake server verification shared by `PostgresSQLConnection::on_handshake` +/// and `MySQLConnection::do_handshake`. +/// +/// The caller has already established that its SSL mode is `VerifyCa` or `VerifyFull` and +/// that `reject_unauthorized` is set; `verify_full` selects the additional hostname +/// identity check. Returns `false` when the connection must be rejected — failure +/// reporting stays at the call site. +pub(crate) fn verify_tls_server( + verify_full: bool, + tls_config: &SSLConfig, + native_handle: Option<*mut core::ffi::c_void>, + error_no: i32, +) -> bool { + if error_no != 0 { + return false; + } + if !verify_full { + return true; + } + // VerifyFull additionally requires the certificate identity to match the intended + // host. Absence of a configured server name is not a license to skip the check — + // fail closed. + let servername = tls_config.server_name(); + if servername.is_null() { + return false; + } + let ssl_ptr: *mut bun_boringssl_sys::SSL = + native_handle.map_or(core::ptr::null_mut(), |p| p.cast()); + // SAFETY: the native handle of a connected TLS socket is `SSL*`, live for the + // duration of the handshake callback. + let Some(ssl) = (unsafe { ssl_ptr.as_mut() }) else { + return false; + }; + // SAFETY: `servername` is a NUL-terminated C string owned by `tls_config` for the + // connection lifetime. + let hostname = unsafe { bun_core::ffi::cstr(servername) }.to_bytes(); + bun_boringssl::check_server_identity(ssl, hostname) +} + +/// A string argument decoded to UTF-8 exactly once, shared by the null-byte +/// check in [`parse`] and by both drivers' connection buffers — so neither +/// caller re-runs the conversion. +pub(crate) struct Utf8Arg { + /// Declared before `_source` so it drops first: on the all-ASCII fast path + /// it is a `ZigStringSlice::Static` borrow of `_source`'s bytes. + utf8: ZigStringSlice, + /// Keeps the backing `WTFStringImpl` alive while `utf8` may borrow its bytes. + _source: OwnedString, +} + +impl Utf8Arg { + fn new(source: OwnedString) -> Self { + let utf8 = source.to_utf8_without_ref(); + Self { + utf8, + _source: source, + } + } + + /// The UTF-8 bytes. + pub(crate) fn slice(&self) -> &[u8] { + self.utf8.slice() + } + + /// Consume into owned bytes — moves the conversion's buffer when it + /// allocated one, copies the borrowed view otherwise. + pub(crate) fn into_boxed_bytes(self) -> Box<[u8]> { + self.utf8.into_vec().into_boxed_slice() + } +} + +/// The 15 parsed `createInstance` arguments common to Postgres and MySQL. +pub(crate) struct ConnectionArgs { + pub hostname: OwnedString, + pub port: i32, + pub username: Utf8Arg, + pub password: Utf8Arg, + pub database: Utf8Arg, + pub options: Utf8Arg, + pub path: Utf8Arg, + pub ssl_mode: Mode, + pub tls: TlsGuard, + pub on_connect: JSValue, + pub on_close: JSValue, + pub idle_timeout: i32, + pub connection_timeout: i32, + pub max_lifetime: i32, + pub use_unnamed_prepared_statements: bool, +} + +/// Decode `arguments[0..=14]`, building the TLS `SSL_CTX` when `sslMode != Disable`. +/// +/// `Ok(None)` means a JS exception is already pending and the caller must return +/// `JSValue::ZERO` (the Zig `return .zero` paths). +pub(crate) fn parse( + vm: &mut VirtualMachine, + global_object: &JSGlobalObject, + callframe: &CallFrame, +) -> JsResult>> { + let arguments = callframe.arguments(); + + let hostname = OwnedString::new(arguments[0].to_bun_string(global_object)?); + let port = arguments[1].coerce::(global_object)?; + let username = Utf8Arg::new(OwnedString::new(arguments[2].to_bun_string(global_object)?)); + let password = Utf8Arg::new(OwnedString::new(arguments[3].to_bun_string(global_object)?)); + let database = Utf8Arg::new(OwnedString::new(arguments[4].to_bun_string(global_object)?)); + let ssl_mode = Mode::from_i32(arguments[5].to_int32()); + + let tls_object = arguments[6]; + let mut tls_config = SSLConfig::default(); + let mut secure: Option<*mut uws::SslCtx> = None; + if !ssl_mode.is_disable() { + tls_config = if tls_object.is_boolean() && tls_object.to_boolean() { + SSLConfig::default() + } else if tls_object.is_object() { + match SSLConfig::from_js(&mut *vm, global_object, tls_object) { + Ok(opt) => opt.unwrap_or_default(), + Err(_) => return Ok(None), + } + } else { + return Err(global_object + .throw_invalid_arguments(format_args!("tls must be a boolean or an object"))); + }; + + if global_object.has_exception() { + drop(tls_config); + return Ok(None); + } + + // We always request the cert so we can verify it, and we manually abort the + // connection if the hostname doesn't match. Built here — not at STARTTLS time — so + // cert/CA errors throw synchronously; the per-VM weak `SSLContextCache` shares one + // `SSL_CTX*` per distinct config across pooled connections / reconnects. + let mut err = uws::create_bun_socket_error_t::none; + secure = vm + .ssl_ctx_cache() + .get_or_create_opts(&tls_config.as_usockets_for_client_verification(), &mut err); + if secure.is_none() { + drop(tls_config); + return Err( + global_object.throw_value(crate::jsc::create_bun_socket_error_to_js( + err, + global_object, + )), + ); + } + } + + // Covers the throwing parses / null-byte checks below and everything the caller does + // until ownership transfers into the connection struct. + let tls: TlsGuard = scopeguard::guard((secure, tls_config), release_tls as fn(TlsPair)); + + let options = Utf8Arg::new(OwnedString::new(arguments[7].to_bun_string(global_object)?)); + let path = Utf8Arg::new(OwnedString::new(arguments[8].to_bun_string(global_object)?)); + + // Reject null bytes in connection parameters to prevent wire-protocol parameter + // injection (null bytes act as field terminators in both the Postgres `key\0value\0` + // startup message and the MySQL handshake). + for (s, name) in [ + (&username, "username"), + (&password, "password"), + (&database, "database"), + (&path, "path"), + ] { + let utf8 = s.slice(); + if !utf8.is_empty() && strings::index_of_char(utf8, 0).is_some() { + return Err(global_object + .throw_invalid_arguments(format_args!("{name} must not contain null bytes"))); + } + } + + Ok(Some(ConnectionArgs { + hostname, + port, + username, + password, + database, + options, + path, + ssl_mode, + tls, + on_connect: arguments[9], + on_close: arguments[10], + idle_timeout: arguments[11].to_int32(), + connection_timeout: arguments[12].to_int32(), + max_lifetime: arguments[13].to_int32(), + use_unnamed_prepared_statements: arguments[14].as_boolean(), + })) +} diff --git a/src/sql_jsc/shared/query_args.rs b/src/sql_jsc/shared/query_args.rs new file mode 100644 index 00000000000..5f16c0efc2d --- /dev/null +++ b/src/sql_jsc/shared/query_args.rs @@ -0,0 +1,72 @@ +//! Shared `createInstance` argument parsing for the SQL query bindings. +//! +//! `PostgresSQLQuery::call` and `JSMySQLQuery::create_instance` receive the same six JS +//! arguments (`query`, `values`, `pendingValue`, `columns`, `bigint`, `simple`); [`parse`] +//! owns the one copy of that decoding and validation — mirroring +//! [`connection_args::parse`](crate::shared::connection_args::parse) for connections — so +//! the two drivers can't drift. + +use crate::jsc::{CallFrame, JSGlobalObject, JSType, JSValue, JsResult}; + +/// The parsed `createInstance` arguments common to Postgres and MySQL. +pub(crate) struct QueryArgs { + /// `arguments[0]` — validated to be a JS string. + pub query: JSValue, + /// `arguments[1]` — validated to be a JS array. + pub values: JSValue, + /// `arguments[2]` — validated to be array-like. + pub pending_value: JSValue, + /// `arguments[3]` — may be `undefined`. + pub columns: JSValue, + pub bigint: bool, + pub simple: bool, +} + +/// Decode and validate `arguments[0..=5]`. +pub(crate) fn parse(global_this: &JSGlobalObject, callframe: &CallFrame) -> JsResult { + let mut arguments = callframe.arguments().iter().copied(); + + let Some(query) = arguments.next() else { + return Err(global_this.throw(format_args!("query must be a string"))); + }; + let Some(values) = arguments.next() else { + return Err(global_this.throw(format_args!("values must be an array"))); + }; + + if !query.is_string() { + return Err(global_this.throw(format_args!("query must be a string"))); + } + + if values.js_type() != JSType::Array { + return Err(global_this.throw(format_args!("values must be an array"))); + } + + let pending_value = arguments.next().unwrap_or(JSValue::UNDEFINED); + let columns = arguments.next().unwrap_or(JSValue::UNDEFINED); + let js_bigint = arguments.next().unwrap_or(JSValue::FALSE); + let js_simple = arguments.next().unwrap_or(JSValue::FALSE); + + let bigint = js_bigint.is_boolean() && js_bigint.as_boolean(); + let simple = js_simple.is_boolean() && js_simple.as_boolean(); + if simple { + if values.get_length(global_this)? > 0 { + return Err(global_this + .throw_invalid_arguments(format_args!("simple query cannot have parameters"))); + } + if query.get_length(global_this)? >= i32::MAX as u64 { + return Err(global_this.throw_invalid_arguments(format_args!("query is too long"))); + } + } + if !pending_value.js_type().is_array_like() { + return Err(global_this.throw_invalid_argument_type("query", "pendingValue", "Array")); + } + + Ok(QueryArgs { + query, + values, + pending_value, + columns, + bigint, + simple, + }) +} diff --git a/test/js/sql/sql-mysql-createinstance-validation.test.ts b/test/js/sql/sql-mysql-createinstance-validation.test.ts new file mode 100644 index 00000000000..31f4cc82eb9 --- /dev/null +++ b/test/js/sql/sql-mysql-createinstance-validation.test.ts @@ -0,0 +1,51 @@ +import { SQL } from "bun"; +import { describe, expect, test } from "bun:test"; + +// Native createInstance argument validation shared between the MySQL and +// Postgres drivers (src/sql_jsc/shared/connection_args.rs / query_args.rs). +// Every error here is thrown by the native parser before any socket is +// created, so no server (and no docker) is needed — the address below is +// never dialed. The Postgres-adapter twin of this suite lives in sql.test.ts. +describe("shared createInstance validation (no server)", () => { + const base = { adapter: "mysql", hostname: "127.0.0.1", port: 1, max: 1 } as const; + + // Connection parameters are written into the NUL-delimited MySQL handshake, + // so a NUL byte would inject extra fields; the native parser must refuse + // them before connecting. + test.concurrent.each(["username", "password", "database"] as const)( + "rejects %s containing null bytes", + async field => { + await using sql = new SQL({ ...base, username: "u", [field]: "a\0b" }); + // `Query` is a lazy thenable, so collect the rejection explicitly. + const err: any = await sql`select 1`.then( + () => null, + e => e, + ); + expect(err?.message).toBe(`${field} must not contain null bytes`); + }, + ); + + test.concurrent("rejects tls that is neither a boolean nor an object", async () => { + // A truthy non-boolean/non-object upgrades sslMode to `require` in JS and + // reaches the native parser as-is. + await using sql = new SQL({ ...base, username: "u", tls: 1 as any }); + const err: any = await sql`select 1`.then( + () => null, + e => e, + ); + expect(err?.message).toBe("tls must be a boolean or an object"); + }); + + test.concurrent("rejects simple queries with parameters", async () => { + await using sql = new SQL({ ...base, username: "u" }); + // Query-handle creation fails before the pool ever connects. + const err: any = await sql + .unsafe("select ?", [1]) + .simple() + .then( + () => null, + e => e, + ); + expect(err?.message).toBe("simple query cannot have parameters"); + }); +}); diff --git a/test/js/sql/sql-mysql-duplicate-columns.test.ts b/test/js/sql/sql-mysql-duplicate-columns.test.ts new file mode 100644 index 00000000000..c0b2368356c --- /dev/null +++ b/test/js/sql/sql-mysql-duplicate-columns.test.ts @@ -0,0 +1,183 @@ +import { SQL } from "bun"; +import { describe, expect, test } from "bun:test"; +import { once } from "events"; +import net from "net"; + +// Duplicate result-set column names, driven through a mock server so the test +// runs without docker. The statement decode path retags every occurrence but +// the last as a duplicate (src/sql_jsc/shared/CachedStructure.rs +// `mark_duplicate_columns`): object rows keep only the last occurrence, while +// `.values()` keeps every cell. +describe("duplicate column names (mock server, no docker)", () => { + const MYSQL_TYPE_VAR_STRING = 0xfd; + + function u16le(n: number): Buffer { + return Buffer.from([n & 0xff, (n >> 8) & 0xff]); + } + function u24le(n: number): Buffer { + return Buffer.from([n & 0xff, (n >> 8) & 0xff, (n >> 16) & 0xff]); + } + function u32le(n: number): Buffer { + return Buffer.from([n & 0xff, (n >> 8) & 0xff, (n >> 16) & 0xff, (n >>> 24) & 0xff]); + } + function packet(seq: number, payload: Buffer): Buffer { + return Buffer.concat([u24le(payload.length), Buffer.from([seq]), payload]); + } + function lenencStr(s: string): Buffer { + const buf = Buffer.from(s, "utf-8"); + if (buf.length >= 0xfb) throw new Error("lenenc: long form not needed for this test"); + return Buffer.concat([Buffer.from([buf.length]), buf]); + } + + const CLIENT_PROTOCOL_41 = 1 << 9; + const CLIENT_SECURE_CONNECTION = 1 << 15; + const CLIENT_PLUGIN_AUTH = 1 << 19; + const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21; + const CLIENT_DEPRECATE_EOF = 1 << 24; + const SERVER_CAPS = + CLIENT_PROTOCOL_41 | + CLIENT_SECURE_CONNECTION | + CLIENT_PLUGIN_AUTH | + CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | + CLIENT_DEPRECATE_EOF; + + function handshakeV10(): Buffer { + const authData1 = Buffer.alloc(8, 0x61); + const authData2 = Buffer.alloc(13, 0x62); + authData2[12] = 0; + return packet( + 0, + Buffer.concat([ + Buffer.from([10]), + Buffer.from("mock-5.7.0\0"), + u32le(1), + authData1, + Buffer.from([0]), + u16le(SERVER_CAPS & 0xffff), + Buffer.from([0x2d]), + u16le(0x0002), + u16le((SERVER_CAPS >>> 16) & 0xffff), + Buffer.from([21]), + Buffer.alloc(10, 0), + authData2, + Buffer.from("mysql_native_password\0"), + ]), + ); + } + function okPacket(seq: number, header = 0x00): Buffer { + return packet(seq, Buffer.from([header, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00])); + } + function varStringColumn(name: string): Buffer { + return Buffer.concat([ + lenencStr("def"), + lenencStr(""), + lenencStr("t"), + lenencStr("t"), + lenencStr(name), + lenencStr(name), + Buffer.from([0x0c]), + u16le(33), // utf8_general_ci + u32le(1024), + Buffer.from([MYSQL_TYPE_VAR_STRING]), + u16le(0), + Buffer.from([0]), + Buffer.from([0, 0]), + ]); + } + + // Both columns are named "a"; the row carries distinct values so the test + // can tell which occurrence the object row kept. + const FIRST = "first"; + const LAST = "last"; + + // Text-protocol result set: count, two identically-named columns, one row. + function textResultSet(startSeq: number): Buffer { + let seq = startSeq; + return Buffer.concat([ + packet(seq++, Buffer.from([0x02])), + packet(seq++, varStringColumn("a")), + packet(seq++, varStringColumn("a")), + packet(seq++, Buffer.concat([lenencStr(FIRST), lenencStr(LAST)])), + okPacket(seq++, 0xfe), + ]); + } + function stmtPrepareOK(startSeq: number, stmtId: number): Buffer { + let seq = startSeq; + return Buffer.concat([ + packet( + seq++, + Buffer.concat([Buffer.from([0x00]), u32le(stmtId), u16le(2), u16le(0), Buffer.from([0x00]), u16le(0)]), + ), + packet(seq++, varStringColumn("a")), + packet(seq++, varStringColumn("a")), + ]); + } + // Binary row: 0x00 header, 1-byte NULL bitmap (2 columns + 2 reserved bits + // fit in one byte), then each value as a length-encoded string. + function binaryResultSet(startSeq: number): Buffer { + let seq = startSeq; + return Buffer.concat([ + packet(seq++, Buffer.from([0x02])), + packet(seq++, varStringColumn("a")), + packet(seq++, varStringColumn("a")), + packet(seq++, Buffer.concat([Buffer.from([0x00]), Buffer.from([0x00]), lenencStr(FIRST), lenencStr(LAST)])), + okPacket(seq++, 0xfe), + ]); + } + + function startMockServer(): net.Server { + const server = net.createServer(socket => { + let buffered = Buffer.alloc(0); + let authed = false; + let stmtId = 0; + socket.write(handshakeV10()); + socket.on("data", chunk => { + buffered = Buffer.concat([buffered, chunk]); + while (buffered.length >= 4) { + const len = buffered[0] | (buffered[1] << 8) | (buffered[2] << 16); + if (buffered.length < 4 + len) break; + const seq = buffered[3]; + const payload = buffered.subarray(4, 4 + len); + buffered = buffered.subarray(4 + len); + if (!authed) { + authed = true; + socket.write(okPacket(seq + 1)); + continue; + } + const cmd = payload[0]; + if (cmd === 0x16 /* COM_STMT_PREPARE */) { + socket.write(stmtPrepareOK(seq + 1, ++stmtId)); + } else if (cmd === 0x17 /* COM_STMT_EXECUTE */) { + socket.write(binaryResultSet(seq + 1)); + } else if (cmd === 0x03 /* COM_QUERY */) { + socket.write(textResultSet(seq + 1)); + } else if (cmd === 0x19 /* COM_STMT_CLOSE */) { + // no response expected + } else { + socket.end(); + } + } + }); + }); + server.listen(0, "127.0.0.1"); + return server; + } + + test("object rows keep the last occurrence; .values() keeps every cell", async () => { + const server = startMockServer(); + await once(server, "listening"); + const { port } = server.address() as net.AddressInfo; + try { + await using sql = new SQL({ url: `mysql://root@127.0.0.1:${port}/db`, max: 1 }); + + // Binary protocol (prepared statement): last one wins in object mode. + expect(await sql`select 1 as a, 2 as a`).toEqual([{ a: LAST }]); + // Text protocol decodes the same way. + expect(await sql`select 1 as a, 2 as a`.simple()).toEqual([{ a: LAST }]); + // `.values()` is positional and must keep both cells. + expect(await sql`select 1 as a, 2 as a`.values()).toEqual([[FIRST, LAST]]); + } finally { + await new Promise(r => server.close(() => r())); + } + }); +}); diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index f05e49a2385..608f8e432f3 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -12758,6 +12758,75 @@ test("rejects Postgres connection options containing null bytes", async () => { await ok.close(); }); +// Native createInstance argument validation (src/sql_jsc/shared/ +// connection_args.rs / query_args.rs) — the username/password/database and +// tls checks the JS layer delegates to the driver. Every error here is thrown +// by the native parser before any socket is created, so no server needs to be +// listening: the address below is never dialed. +describe("shared createInstance validation (no server)", () => { + const base = { adapter: "postgres", hostname: "127.0.0.1", port: 1, max: 1 } as const; + + // Credentials are written into the NUL-delimited Postgres StartupMessage as + // `key\0value\0`; a NUL byte inside one would inject extra parameters, so + // the native parser must refuse it before connecting. The rejection is an + // ERR_INVALID_ARG_TYPE TypeError, matching the MySQL adapter and the Zig + // reference (`PostgresSQLConnection.zig` `throwInvalidArguments`) — the + // previous port threw a plain code-less Error here. + test.concurrent.each(["username", "password", "database"] as const)( + "rejects %s containing null bytes", + async field => { + await using sql = new SQL({ ...base, username: "u", [field]: "a\0b" }); + // `Query` is a lazy thenable, so collect the rejection explicitly. + const err: any = await sql`select 1`.then( + () => null, + e => e, + ); + expect(err?.message).toBe(`${field} must not contain null bytes`); + expect(err?.code).toBe("ERR_INVALID_ARG_TYPE"); + expect(err instanceof TypeError).toBe(true); + }, + ); + + test.concurrent("SSL_CTX creation failure throws the structured BoringSSL error", async () => { + // An unparseable CA makes `SSL_CTX` creation fail synchronously inside + // createInstance, before any socket exists. The failure carries + // `code: "ERR_BORINGSSL"` like the MySQL adapter and the Zig reference + // (`err.toJS(globalObject)`) — the previous port threw a plain code-less + // Error with a static message. + await using sql = new SQL({ ...base, username: "u", tls: { ca: "not a pem" } }); + const err: any = await sql`select 1`.then( + () => null, + e => e, + ); + expect(err?.message).toBe("Invalid CA"); + expect(err?.code).toBe("ERR_BORINGSSL"); + }); + + test.concurrent("rejects tls that is neither a boolean nor an object", async () => { + // A truthy non-boolean/non-object upgrades sslMode to `require` in JS and + // reaches the native parser as-is. + await using sql = new SQL({ ...base, username: "u", tls: 1 as any }); + const err: any = await sql`select 1`.then( + () => null, + e => e, + ); + expect(err?.message).toBe("tls must be a boolean or an object"); + }); + + test.concurrent("rejects simple queries with parameters", async () => { + await using sql = new SQL({ ...base, username: "u" }); + // Query-handle creation fails before the pool ever connects. + const err: any = await sql + .unsafe("select $1", [1]) + .simple() + .then( + () => null, + e => e, + ); + expect(err?.message).toBe("simple query cannot have parameters"); + }); +}); + // A Postgres server controls two independent column counts: the // RowDescription's field list (which sizes the per-row cell buffer and the // cached row Structure) and each DataRow's own column count. When a DataRow