Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/sql/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ pub mod shared {
pub mod connection_flags;
#[path = "Data.rs"]
pub mod data;
#[path = "QueryStatus.rs"]
pub mod query_status;
#[path = "SQLQueryResultMode.rs"]
pub mod sql_query_result_mode;
#[path = "StatementStatus.rs"]
pub mod statement_status;

pub use column_identifier::ColumnIdentifier;
pub use connection_flags::ConnectionFlags;
Expand All @@ -31,8 +35,6 @@ pub mod mysql {
pub mod mysql_request;
#[path = "MySQLTypes.rs"]
pub mod mysql_types;
#[path = "QueryStatus.rs"]
pub mod query_status;
#[path = "SSLMode.rs"]
pub mod ssl_mode;
#[path = "StatusFlags.rs"]
Expand Down Expand Up @@ -118,11 +120,12 @@ pub mod mysql {
pub use crate::mysql::mysql_types::FieldType;
}

pub use crate::shared::query_status;
pub use crate::shared::query_status::Status as QueryStatus;
pub use auth_method::AuthMethod;
pub use capabilities::Capabilities;
pub use connection_state::ConnectionState;
pub use mysql_query_result::MySQLQueryResult;
pub use query_status::Status as QueryStatus;
pub use ssl_mode::SSLMode;
pub use status_flags::{StatusFlag, StatusFlags};
pub use tls_status::TLSStatus;
Expand Down
14 changes: 14 additions & 0 deletions src/sql/postgres/protocol/ErrorResponse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ impl ErrorResponse {
) -> Result<Self, AnyPostgresError> {
Self::decode_internal(NewReader { wrapped: context })
}

/// `NoticeResponse` decode: a declared length below 4 decodes as an empty
/// notice instead of failing, unlike `ErrorResponse`.
pub fn decode_notice_internal<Container: super::new_reader::ReaderContext>(
mut reader: NewReader<Container>,
) -> Result<Self, AnyPostgresError> {
let remaining_bytes = reader.length()?.saturating_sub(4);
if remaining_bytes > 0 {
return Ok(Self {
messages: FieldMessage::decode_list::<Container>(reader)?,
});
}
Ok(Self::default())
}
}

// `to_js` lives on an extension trait in the `bun_sql_jsc` crate.
11 changes: 8 additions & 3 deletions src/sql/postgres/protocol/FieldMessage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ pub enum FieldMessage {

impl fmt::Display for FieldMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.payload())
}
}

impl FieldMessage {
/// Every variant carries a single `bun.String` payload.
pub fn payload(&self) -> &String {
match self {
Comment thread
alii marked this conversation as resolved.
FieldMessage::Severity(s)
| FieldMessage::LocalizedSeverity(s)
Expand All @@ -47,12 +54,10 @@ impl fmt::Display for FieldMessage {
| FieldMessage::Constraint(s)
| FieldMessage::File(s)
| FieldMessage::Line(s)
| FieldMessage::Routine(s) => write!(f, "{s}"),
| FieldMessage::Routine(s) => s,
}
}
}

impl FieldMessage {
pub fn decode_list<Context: super::new_reader::ReaderContext>(
mut reader: NewReader<Context>,
) -> Result<Vec<FieldMessage>, AnyPostgresError> {
Expand Down
40 changes: 5 additions & 35 deletions src/sql/postgres/protocol/NoticeResponse.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,5 @@
use super::field_message::FieldMessage;
use super::new_reader::NewReader;
use crate::postgres::AnyPostgresError;

#[derive(Default)]
pub struct NoticeResponse {
pub messages: Vec<FieldMessage>,
}

// Vec<FieldMessage> drops each element (FieldMessage's Drop) and the buffer
// automatically, so no explicit Drop body is needed.

impl NoticeResponse {
pub fn decode_internal<Container: super::new_reader::ReaderContext>(
mut reader: NewReader<Container>,
) -> Result<Self, AnyPostgresError> {
let mut remaining_bytes = reader.length()?;
remaining_bytes = remaining_bytes.saturating_sub(4);

if remaining_bytes > 0 {
return Ok(Self {
messages: FieldMessage::decode_list::<Container>(reader)?,
});
}
Ok(Self::default())
}

pub fn decode<Container: super::new_reader::ReaderContext>(
context: Container,
) -> Result<Self, AnyPostgresError> {
Self::decode_internal(NewReader { wrapped: context })
}
}

// `to_js` lives as an extension-trait method in the bun_sql_jsc crate.
/// NoticeResponse has the same wire format as ErrorResponse — a length-prefixed
/// list of field messages — so it reuses the same type. Notices decode via
/// `decode_notice_internal`, which tolerates a declared length below 4
/// (decoding as empty) where `ErrorResponse` fails.
pub type NoticeResponse = crate::postgres::protocol::error_response::ErrorResponse;
File renamed without changes.
13 changes: 13 additions & 0 deletions src/sql/shared/StatementStatus.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum Status {
Pending,
Parsing,
Prepared,
Failed,
}

impl Status {
pub fn is_running(self) -> bool {
Comment thread
alii marked this conversation as resolved.
Outdated
self == Status::Parsing
}
}
4 changes: 4 additions & 0 deletions src/sql_jsc/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pub mod shared {
#[path = "CachedStructure.rs"]
pub mod cached_structure;

pub mod connection_ctor_args;

Comment thread
alii marked this conversation as resolved.
pub mod datetime_text;

#[path = "ObjectIterator.rs"]
Expand All @@ -28,6 +30,8 @@ pub mod shared {
#[path = "QueryBindingIterator.rs"]
pub mod query_binding_iterator;

pub mod query_ctor_args;

#[path = "SQLDataCell.rs"]
pub mod sql_data_cell;

Expand Down
86 changes: 13 additions & 73 deletions src/sql_jsc/mysql/JSMySQLConnection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use core::ffi::c_void;
use crate::jsc::{
CallFrame, EventLoopSqlExt as _, EventLoopTimer, EventLoopTimerState, EventLoopTimerTag,
GlobalRef, HasAutoFlush, JSGlobalObject, JSValue, JsCell, JsRef, JsResult, KeepAlive,
VirtualMachine, VirtualMachineSqlExt as _, api::server_config::SSLConfig,
codegen::js_mysql_connection as js, webcore::AutoFlusher,
VirtualMachine, VirtualMachineSqlExt as _, codegen::js_mysql_connection as js,
webcore::AutoFlusher,
};
use crate::shared::CachedStructure;
use bun_boringssl_sys as boringssl;
use crate::shared::connection_ctor_args::{self, ConnectionCtorArgs};
use bun_core::strings;
use bun_core::{TimespecMockMode, timespec};
use bun_ptr::{AsCtxPtr, BackRef, ParentRef};
Expand Down Expand Up @@ -470,85 +470,25 @@ impl JSMySQLConnection {
// no other live borrow in this scope.
let vm = global_object.bun_vm().as_mut();
let arguments = callframe.arguments();
let hostname_str = bun_core::OwnedString::new(arguments[0].to_bun_string(global_object)?);
let port = arguments[1].coerce::<i32>(global_object)?;

let username_str = bun_core::OwnedString::new(arguments[2].to_bun_string(global_object)?);
let password_str = bun_core::OwnedString::new(arguments[3].to_bun_string(global_object)?);
let database_str = bun_core::OwnedString::new(arguments[4].to_bun_string(global_object)?);
// TODO: update this to match MySQL.
let ssl_mode: SSLMode = match arguments[5].to_int32() {
0 => SSLMode::Disable,
1 => SSLMode::Prefer,
2 => SSLMode::Require,
3 => SSLMode::VerifyCa,
4 => SSLMode::VerifyFull,
_ => SSLMode::Disable,
let Some(args) = ConnectionCtorArgs::<SSLMode>::parse(global_object, &mut *vm, arguments)?
else {
return Ok(JSValue::ZERO);
};

let tls_object = arguments[6];

let mut tls_config: SSLConfig = SSLConfig::default();
let mut secure: Option<*mut uws::SslCtx> = None;
if ssl_mode != SSLMode::Disable {
tls_config = if tls_object.is_boolean() && tls_object.to_boolean() {
SSLConfig::default()
} else if tls_object.is_object() {
match SSLConfig::from_js(&mut *vm, global_object, tls_object) {
Ok(Some(c)) => c,
Ok(None) => SSLConfig::default(),
Err(_) => return Ok(JSValue::ZERO),
}
} else {
return Err(global_object
.throw_invalid_arguments(format_args!("tls must be a boolean or an object")));
};

if global_object.has_exception() {
drop(tls_config);
return Ok(JSValue::ZERO);
}

// We always request the cert so we can verify it and also we manually
// abort the connection if the hostname doesn't match. Built here so
// CA/cert errors throw synchronously, applied later by upgradeToTLS.
// Goes through the per-VM weak `SSLContextCache` so every pooled
// connection / reconnect shares one `SSL_CTX*` per distinct config.
let mut err = uws::create_bun_socket_error_t::none;
secure = vm
.ssl_ctx_cache()
.get_or_create_opts(&tls_config.as_usockets_for_client_verification(), &mut err);
if secure.is_none() {
drop(tls_config);
return Err(
global_object.throw_value(crate::jsc::create_bun_socket_error_to_js(
err,
global_object,
)),
);
}
}
// Covers `try arguments[7/8].toBunString()` and the null-byte rejection
// below. Ownership passes to `MySQLConnection.init` once `Box::new`
// succeeds — we null the locals at that point so the connect-fail path
// (which `deref()`s the connection) doesn't double-free.
let tls_guard = scopeguard::guard((secure, tls_config), |(s, cfg)| {
if let Some(s) = s {
// SAFETY: secure was created by ssl_ctx_cache; we own one ref until transferred.
unsafe { boringssl::SSL_CTX_free(s) };
}
drop(cfg);
});
let tls_guard = connection_ctor_args::guard_tls(args.secure, args.tls_config);

let options_str = bun_core::OwnedString::new(arguments[7].to_bun_string(global_object)?);
let path_str = bun_core::OwnedString::new(arguments[8].to_bun_string(global_object)?);

// `init` takes `Box<[u8]>` per field (each separately owned), so we
// copy each string into its own allocation. `options_buf` becomes an
// empty box.
let username: Box<[u8]> = Box::from(username_str.to_utf8_without_ref().slice());
let password: Box<[u8]> = Box::from(password_str.to_utf8_without_ref().slice());
let database: Box<[u8]> = Box::from(database_str.to_utf8_without_ref().slice());
let username: Box<[u8]> = Box::from(args.username_str.to_utf8_without_ref().slice());
let password: Box<[u8]> = Box::from(args.password_str.to_utf8_without_ref().slice());
let database: Box<[u8]> = Box::from(args.database_str.to_utf8_without_ref().slice());
let options: Box<[u8]> = Box::from(options_str.to_utf8_without_ref().slice());
let path: Box<[u8]> = Box::from(path_str.to_utf8_without_ref().slice());
let options_buf: Box<[u8]> = Box::default();
Expand Down Expand Up @@ -595,7 +535,7 @@ impl JSMySQLConnection {
options_buf,
tls_config,
secure,
ssl_mode,
args.ssl_mode,
allow_public_key_retrieval,
)),
auto_flusher: JsCell::new(AutoFlusher::default()),
Expand All @@ -616,7 +556,7 @@ impl JSMySQLConnection {
let this = ParentRef::from(core::ptr::NonNull::new(ptr).expect("heap::into_raw non-null"));

{
let hostname = hostname_str.to_utf8();
let hostname = args.hostname_str.to_utf8();

// MySQL always opens plain TCP first; STARTTLS adopts into the TLS
// group after the SSLRequest exchange.
Expand All @@ -636,7 +576,7 @@ impl JSMySQLConnection {
uws::DispatchKind::Mysql,
None,
hostname.slice(),
port,
args.port,
ptr,
false,
)
Expand Down
46 changes: 9 additions & 37 deletions src/sql_jsc/mysql/JSMySQLQuery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::jsc::{
self as jsc, CallFrame, JSGlobalObject, JSGlobalObjectSqlExt as _, JSValue, JsRef, JsResult,
VirtualMachine, VirtualMachineSqlExt as _,
};
use crate::shared::query_ctor_args::QueryCtorArgs;
use bun_jsc::JsCell;
use bun_ptr::{AsCtxPtr, BackRef, ParentRef};
use bun_sql::mysql::MySQLQueryResult;
Expand Down Expand Up @@ -95,43 +96,14 @@ impl JSMySQLQuery {
global_this: &JSGlobalObject,
callframe: &CallFrame,
) -> JsResult<JSValue> {
let arguments = callframe.arguments();
let mut args = jsc::call_frame::ArgumentsSlice::init(global_this.sql_vm(), arguments);
// defer args.deinit() — handled by Drop
let Some(query) = args.next_eat() else {
return Err(global_this.throw(format_args!("query must be a string")));
};
let Some(values) = args.next_eat() else {
return Err(global_this.throw(format_args!("values must be an array")));
};

if !query.is_string() {
return Err(global_this.throw(format_args!("query must be a string")));
}

if values.js_type() != jsc::JSType::Array {
return Err(global_this.throw(format_args!("values must be an array")));
}

let pending_value: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED);
let columns: JSValue = args.next_eat().unwrap_or(JSValue::UNDEFINED);
let js_bigint: JSValue = args.next_eat().unwrap_or(JSValue::FALSE);
let js_simple: JSValue = args.next_eat().unwrap_or(JSValue::FALSE);

let bigint = js_bigint.is_boolean() && js_bigint.as_boolean();
let simple = js_simple.is_boolean() && js_simple.as_boolean();
if simple {
if values.get_length(global_this)? > 0 {
return Err(global_this
.throw_invalid_arguments(format_args!("simple query cannot have parameters")));
}
if query.get_length(global_this)? >= i32::MAX as u64 {
return Err(global_this.throw_invalid_arguments(format_args!("query is too long")));
}
}
if !pending_value.js_type().is_array_like() {
return Err(global_this.throw_invalid_argument_type("query", "pendingValue", "Array"));
}
let QueryCtorArgs {
query,
values,
pending_value,
columns,
bigint,
simple,
} = QueryCtorArgs::parse(global_this, callframe.arguments())?;

let this_ptr = bun_core::heap::into_raw(Box::new(Self {
this_value: JsCell::new(JsRef::empty()),
Expand Down
Loading
Loading