diff --git a/cargo-pgrx/README.md b/cargo-pgrx/README.md index 2f87a09a5..24baea91d 100644 --- a/cargo-pgrx/README.md +++ b/cargo-pgrx/README.md @@ -886,10 +886,17 @@ If you just want to look at the full extension schema that pgrx will generate, u $ cargo pgrx schema --help Generate extension schema files -Usage: cargo pgrx schema [OPTIONS] [PG_VERSION] +Usage: cargo pgrx schema [OPTIONS] [ARGS]... Arguments: - [PG_VERSION] Do you want to run against pg13, pg14, pg15, pg16, pg17, or pg18? + [ARGS]... First arg may be a PostgreSQL version label (`pg13`..`pg18`). + Remaining args are SQL item names to emit (functions, types, + enums, operators, aggregates, triggers, schemas, extension_sql + blocks). When item names are given, only those items and their + transitive dependencies are emitted, in install order, and + 'MODULE_PATHNAME' is substituted with '$libdir/' so + the output can be replayed directly. Names containing `::` are + matched as Rust paths to disambiguate. Options: -p, --package Package to build (see `cargo help pkgid`) @@ -905,10 +912,66 @@ Options: -o, --out A path to output a produced SQL file (default is `stdout`) -d, --dot A path to output a produced GraphViz DOT file --skip-build Skip building a fresh extension shared object + --no-alter-extension Don't emit `ALTER EXTENSION ... ADD ...` statements when + extracting specific items (see "Attaching Slices" below) -h, --help Print help -V, --version Print version ``` +### Emitting a Slice of the Schema + +Any positional arguments after an optional `pgXX` version label are treated as SQL item +names. The output is restricted to those items plus every dependency they need, in +install order. Names match against each entity's SQL-visible identifier: `name` for +functions, types, enums, aggregates, and schemas; `opname` for operators (for example +`===`); `function_name` for triggers; and `name` for `extension_sql!` blocks. A name +containing `::` is treated as a Rust path and matched against `full_path`, which is the +way to disambiguate collisions (for example two functions named `dup_fn` in different +modules). + +When item names are supplied, every occurrence of `'MODULE_PATHNAME'` in the generated +SQL is substituted with `'$libdir/'`, so the output can be replayed directly +into a database without relying on the extension's control file to resolve +`MODULE_PATHNAME`. + +```shell +# Emit one function and its dependencies, with MODULE_PATHNAME substituted +cargo pgrx schema my_function + +# Multiple items at once +cargo pgrx schema my_function MyType === + +# Specify a Postgres version first, then the items +cargo pgrx schema pg18 my_function MyType + +# Disambiguate with a Rust path when the bare name matches multiple items +cargo pgrx schema my_crate::submodule::dup_fn + +# Write the slice to a file instead of stdout (combines with item selection) +cargo pgrx schema --out /tmp/extracted_schema_objects.sql my_function MyType +``` + +#### Attaching Slices to an Already-Installed Extension + +When item names are supplied, the emitted slice is wrapped in `BEGIN;`/`COMMIT;` +and every created object is followed by an `ALTER EXTENSION "" ADD ...` +statement. Piping the output into a database where the extension is already +installed makes the new objects members of the extension, verifiable via +`pg_depend`. + +```shell +# Add a new function to an already-installed extension +cargo pgrx schema my_new_fn | cargo pgrx connect +``` + +Pass `--no-alter-extension` to opt out of this (for example, to generate SQL +for hand-editing or to match the pre-feature output). + +`extension_sql!()` blocks that don't declare `creates = [...]` cannot be +attached automatically; the emitter prints a warning to stderr naming the +block's `file:line` so the user knows which objects to attach by hand. + + ## Extension Version Upgrade Scripts When creating a pgrx extension using `cargo pgrx new foo`, the new extension template directory tree includes a diff --git a/cargo-pgrx/src/command/install.rs b/cargo-pgrx/src/command/install.rs index 1d950b448..c88621f37 100644 --- a/cargo-pgrx/src/command/install.rs +++ b/cargo-pgrx/src/command/install.rs @@ -388,6 +388,10 @@ fn copy_sql_files( None, None, skip_build, + None, + // install scripts run inside CREATE EXTENSION and auto-attach; + // explicit ALTER EXTENSION would be redundant. + false, output_tracking, )?; } diff --git a/cargo-pgrx/src/command/schema.rs b/cargo-pgrx/src/command/schema.rs index 0790ce460..ab32c5bbf 100644 --- a/cargo-pgrx/src/command/schema.rs +++ b/cargo-pgrx/src/command/schema.rs @@ -16,7 +16,7 @@ use cargo_toml::Manifest; use eyre::WrapErr; use owo_colors::OwoColorize; use pgrx_pg_config::cargo::PgrxManifestExt; -use pgrx_pg_config::{Pgrx, get_target_dir}; +use pgrx_pg_config::{Pgrx, get_target_dir, is_supported_major_version}; use pgrx_sql_entity_graph::section::decode_entities; use pgrx_sql_entity_graph::{ControlFile, PgrxSql, SqlGraphEntity}; use std::path::{Path, PathBuf}; @@ -34,8 +34,16 @@ pub(crate) struct Schema { /// Build in test mode (for `cargo pgrx test`) #[clap(long)] test: bool, - /// Do you want to run against pg13, pg14, pg15, pg16, pg17, or pg18? - pg_version: Option, + /// Positional arguments. + /// + /// The first may be a PostgreSQL version label (`pg13`..`pg18`); every + /// remaining value is an SQL item name to emit (functions, types, + /// enums, operators, aggregates, triggers, schemas, extension_sql + /// blocks). Only those items and their transitive dependencies are + /// emitted, in install order, and `'MODULE_PATHNAME'` is substituted + /// with `'$libdir/'` so the output can be replayed directly. + /// Names containing `::` are matched as Rust paths to disambiguate. + args: Vec, /// Compile for release mode (default is debug) #[clap(long, short)] release: bool, @@ -60,6 +68,12 @@ pub(crate) struct Schema { /// Skip building a fresh extension shared object. #[clap(long)] skip_build: bool, + /// Don't emit `ALTER EXTENSION ... ADD ...` statements when extracting + /// specific items. By default, item mode emits ALTER EXTENSION so the + /// output can be piped into a running database and attached to the + /// already-installed extension. + #[clap(long)] + no_alter_extension: bool, } impl CommandExecute for Schema { @@ -76,6 +90,8 @@ impl CommandExecute for Schema { } }; + let (pg_version, items) = split_positional_args(&self.args); + let pgrx = Pgrx::from_config()?; let (package_manifest, package_manifest_path) = get_package_manifest( &self.features, @@ -86,7 +102,7 @@ impl CommandExecute for Schema { let (_pg_config, _pg_version) = pg_config_and_version( &pgrx, &package_manifest, - self.pg_version.clone(), + pg_version, Some(&mut self.features), true, )?; @@ -96,6 +112,7 @@ impl CommandExecute for Schema { if self.release { CargoProfile::Release } else { CargoProfile::Dev }, )?; + let attach = !self.no_alter_extension; generate_schema( self.manifest_path.as_deref(), self.package.as_deref(), @@ -108,11 +125,33 @@ impl CommandExecute for Schema { self.dot.as_deref(), log_level, self.skip_build, + items, + attach, &mut vec![], ) } } +/// Split the schema command's positional arguments into an optional +/// `pgXX` version label and an optional list of SQL item names. +/// +/// If the first argument parses as a supported PostgreSQL major version it +/// is consumed as `pg_version`; everything after it (or everything, if +/// there is no version) flows through as item names. `None` items means +/// the caller supplied no names at all — as distinct from an empty slice. +fn split_positional_args(args: &[String]) -> (Option, Option<&[String]>) { + let (pg_version, rest) = if let Some((first, rest)) = args.split_first() + && let Some(major) = first.strip_prefix("pg") + && let Ok(major) = major.parse::() + && is_supported_major_version(major) + { + (Some(first.clone()), rest) + } else { + (None, args) + }; + (pg_version, (!rest.is_empty()).then_some(rest)) +} + #[tracing::instrument(level = "error", skip_all, fields( profile = ?profile, test = is_test, @@ -132,6 +171,8 @@ pub(crate) fn generate_schema_for_cli( dot: Option<&Path>, log_level: Option, skip_build: bool, + items: Option<&[String]>, + attach: bool, output_tracking: &mut Vec, ) -> eyre::Result<()> { let manifest = Manifest::from_path(package_manifest_path)?; @@ -160,6 +201,8 @@ pub(crate) fn generate_schema_for_cli( target, path, dot, + items, + attach, output_tracking, manifest, ) @@ -172,10 +215,12 @@ pub(crate) fn generate_schema_implicit( target: Option<&str>, path: Option<&Path>, dot: Option<&Path>, + items: Option<&[String]>, + attach: bool, output_tracking: &mut Vec, manifest: cargo_toml::Manifest, ) -> eyre::Result<()> { - let (control_file_path, _extname) = find_control_file(package_manifest_path)?; + let (control_file_path, extname) = find_control_file(package_manifest_path)?; let lib_name = manifest.lib_name()?; let lib_filename = manifest.lib_filename()?; let versioned_so = get_property(package_manifest_path, "module_pathname")?.is_none(); @@ -206,7 +251,36 @@ pub(crate) fn generate_schema_implicit( let pgrx_sql = PgrxSql::build(entities.into_iter(), lib_name.to_string(), versioned_so) .wrap_err("SQL generation error")?; - if let Some(path) = path { + if let Some(items) = items { + let extension_name = attach.then_some(extname.as_str()); + let sliced = pgrx_sql + .to_sql_for_items(items, &lib_name, extension_name) + .wrap_err("Could not generate SQL for requested items")?; + if let Some(path) = path { + eprintln!( + "{} SQL for {} item(s) to {}", + " Writing".bold().green(), + items.len(), + path.display() + ); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + std::fs::write(path, sliced) + .wrap_err_with(|| format!("Could not write SQL to {}", path.display()))?; + } else { + eprintln!( + "{} SQL for {} item(s) to {}", + " Writing".bold().green(), + items.len(), + "/dev/stdout" + ); + use std::io::Write as _; + std::io::stdout() + .write_all(sliced.as_bytes()) + .wrap_err("Could not write SQL to stdout")?; + } + } else if let Some(path) = path { eprintln!("{} SQL entities to {}", " Writing".bold().green(), path.display()); pgrx_sql .to_file(path) @@ -335,7 +409,11 @@ fn first_build( #[cfg(test)] mod tests { - use super::decode_section_entities; + use super::{decode_section_entities, split_positional_args}; + + fn strs(args: &[&str]) -> Vec { + args.iter().map(|s| (*s).to_owned()).collect() + } #[test] fn test_missing_schema_section_errors() { @@ -346,4 +424,44 @@ mod tests { let error = decode_section_entities(&bin).expect_err("missing section"); assert!(error.to_string().contains("no embedded pgrx schema section found")); } + + #[test] + fn empty_args_yield_no_version_and_no_items() { + let args = strs(&[]); + let (pg, items) = split_positional_args(&args); + assert!(pg.is_none()); + assert!(items.is_none()); + } + + #[test] + fn version_alone_is_captured() { + let args = strs(&["pg18"]); + let (pg, items) = split_positional_args(&args); + assert_eq!(pg.as_deref(), Some("pg18")); + assert!(items.is_none()); + } + + #[test] + fn version_followed_by_items() { + let args = strs(&["pg18", "sum_vec", "MyType", "==="]); + let (pg, items) = split_positional_args(&args); + assert_eq!(pg.as_deref(), Some("pg18")); + assert_eq!(items, Some(&["sum_vec".to_owned(), "MyType".to_owned(), "===".to_owned()][..])); + } + + #[test] + fn items_only_without_version() { + let args = strs(&["sum_vec", "MyType", "==="]); + let (pg, items) = split_positional_args(&args); + assert!(pg.is_none()); + assert_eq!(items, Some(&["sum_vec".to_owned(), "MyType".to_owned(), "===".to_owned()][..])); + } + + #[test] + fn first_arg_that_looks_like_version_but_isnt_is_an_item() { + let args = strs(&["pgfoo", "sum_vec"]); + let (pg, items) = split_positional_args(&args); + assert!(pg.is_none()); + assert_eq!(items, Some(&["pgfoo".to_owned(), "sum_vec".to_owned()][..])); + } } diff --git a/pgrx-sql-entity-graph/src/aggregate/entity.rs b/pgrx-sql-entity-graph/src/aggregate/entity.rs index 35d176415..489e01047 100644 --- a/pgrx-sql-entity-graph/src/aggregate/entity.rs +++ b/pgrx-sql-entity-graph/src/aggregate/entity.rs @@ -24,6 +24,7 @@ use crate::to_sql::ToSql; use crate::to_sql::entity::ToSqlConfigEntity; use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity}; use eyre::{WrapErr, eyre}; +use petgraph::graph::NodeIndex; #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct AggregateTypeEntity<'a> { @@ -184,6 +185,39 @@ fn aggregate_sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyr } } +/// Render the positional argument-type signature for an aggregate as it +/// would appear inside `ALTER EXTENSION … ADD AGGREGATE name(…)`. For +/// ordered-set aggregates the rendering is `(direct ORDER BY args)`; +/// otherwise it is `(args)`. Matches the shape produced by +/// `PgAggregateEntity::to_sql`. +pub(crate) fn render_aggregate_argtypes( + context: &PgrxSql, + owner: NodeIndex, + a: &PgAggregateEntity, +) -> eyre::Result { + let render_slot = |arg: &AggregateTypeEntity| -> eyre::Result { + let slot = arg.name.unwrap_or("aggregate argument"); + let prefix = context.schema_prefix_for_used_type(&owner, slot, &arg.used_ty)?; + let sql = match arg.used_ty.metadata.argument_sql { + Ok(ref mapping) => aggregate_sql_type(mapping, arg.used_ty.composite_type)?, + Err(err) => return Err(err.into()), + }; + let variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" }; + Ok(format!("{variadic}{prefix}{sql}")) + }; + + let args = a.args.iter().map(render_slot).collect::>>()?.join(", "); + let direct = a.direct_args.as_deref().unwrap_or(&[]); + + if a.ordered_set { + let direct_rendered = + direct.iter().map(render_slot).collect::>>()?.join(", "); + Ok(format!("({direct_rendered} ORDER BY {args})")) + } else { + Ok(format!("({args})")) + } +} + impl ToSql for PgAggregateEntity<'_> { fn to_sql(&self, context: &PgrxSql) -> eyre::Result { let self_index = context.aggregates[self]; diff --git a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs index 9c2d8557a..df2ffc1b7 100644 --- a/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs +++ b/pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs @@ -25,6 +25,7 @@ pub use cast::PgCastEntity; pub use operator::PgOperatorEntity; pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem}; +use crate::UsedTypeEntity; use crate::fmt; use crate::metadata::{Returns, SqlArrayMapping, SqlMapping}; use crate::pgrx_sql::PgrxSql; @@ -33,6 +34,7 @@ use crate::to_sql::entity::ToSqlConfigEntity; use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier}; use eyre::{WrapErr, eyre}; +use petgraph::graph::NodeIndex; /// The output of a [`PgExtern`](crate::pg_extern::PgExtern) from `quote::ToTokens::to_tokens`. #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] @@ -112,6 +114,84 @@ fn sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result< } } +/// Render the SQL spelling of one `UsedType`, with schema prefix applied. +/// This is the bit that sits right after `"name" VARIADIC ` in a CREATE +/// FUNCTION signature. +pub(crate) fn render_used_type_sql( + context: &PgrxSql, + owner: NodeIndex, + slot: &str, + used_ty: &UsedTypeEntity, +) -> eyre::Result { + let schema_prefix = context.schema_prefix_for_used_type(&owner, slot, used_ty)?; + let body = match used_ty.metadata.argument_sql { + Ok(SqlMapping::As(ref sql)) => sql.clone(), + Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => { + sql_type(mapping, used_ty.composite_type)? + } + Ok(SqlMapping::Skip) => { + return Err(eyre!("Found a skipped SQL type where SQL should be emitted")); + } + Err(err) => return Err(err.into()), + }; + Ok(format!("{schema_prefix}{body}")) +} + +/// Render the comma-separated argument-type list for a pg_extern, matching +/// the positional shape of CREATE FUNCTION but without names or defaults. +/// Skipped args are filtered out; variadic args get a `VARIADIC ` prefix. +pub(crate) fn render_function_argtypes( + context: &PgrxSql, + owner: NodeIndex, + f: &PgExternEntity, +) -> eyre::Result { + let mut pieces = Vec::new(); + for arg in f.fn_args.iter().filter(|a| a.used_ty.emits_argument_sql()) { + let slot = format!("argument `{}`", arg.pattern); + let rendered = render_used_type_sql(context, owner, &slot, &arg.used_ty)?; + if arg.used_ty.variadic { + pieces.push(format!("VARIADIC {rendered}")); + } else { + pieces.push(rendered); + } + } + Ok(pieces.join(", ")) +} + +/// Render the return type of a pg_extern for contexts that need a plain +/// scalar, such as CREATE CAST. Errors if the function returns a set or a +/// table. +pub(crate) fn render_function_return_type( + context: &PgrxSql, + owner: NodeIndex, + f: &PgExternEntity, +) -> eyre::Result { + let ty = match &f.fn_return { + PgExternReturnEntity::Type { ty } => ty, + PgExternReturnEntity::None => { + return Err(eyre!("Cannot render return type for a function with no return")); + } + other => { + return Err(eyre!("Cannot render a scalar return type for {other:?}")); + } + }; + let schema_prefix = context.schema_prefix_for_used_type(&owner, "return type", ty)?; + let body = match &ty.metadata.return_sql { + Ok(Returns::One(SqlMapping::As(sql))) => sql.clone(), + Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => { + sql_type(mapping, ty.composite_type)? + } + Ok(Returns::One(SqlMapping::Skip)) => { + return Err(eyre!("Return type was SqlMapping::Skip")); + } + Ok(other) => { + return Err(eyre!("Return type is not a scalar: {other:?}")); + } + Err(err) => return Err((*err).into()), + }; + Ok(format!("{schema_prefix}{body}")) +} + impl ToSql for PgExternEntity<'_> { fn to_sql(&self, context: &PgrxSql) -> eyre::Result { let self_index = context.externs[self]; @@ -418,45 +498,23 @@ impl ToSql for PgExternEntity<'_> { .fn_args .first() .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; - let left_arg_schema_prefix = context.schema_prefix_for_used_type( - &self_index, + let left_arg_sql = render_used_type_sql( + context, + self_index, "operator left argument", &left_arg.used_ty, )?; - let left_arg_sql = match left_arg.used_ty.metadata.argument_sql { - Ok(SqlMapping::As(ref sql)) => sql.clone(), - Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => { - sql_type(mapping, left_arg.used_ty.composite_type)? - } - Ok(SqlMapping::Skip) => { - return Err(eyre!( - "Found an skipped SQL type in an operator, this is not valid" - )); - } - Err(err) => return Err(err.into()), - }; let right_arg = self .fn_args .get(1) .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?; - let right_arg_schema_prefix = context.schema_prefix_for_used_type( - &self_index, + let right_arg_sql = render_used_type_sql( + context, + self_index, "operator right argument", &right_arg.used_ty, )?; - let right_arg_sql = match right_arg.used_ty.metadata.argument_sql { - Ok(SqlMapping::As(ref sql)) => sql.clone(), - Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => { - sql_type(mapping, right_arg.used_ty.composite_type)? - } - Ok(SqlMapping::Skip) => { - return Err(eyre!( - "Found an skipped SQL type in an operator, this is not valid" - )); - } - Err(err) => return Err(err.into()), - }; let schema = self .schema @@ -469,16 +527,14 @@ impl ToSql for PgExternEntity<'_> { -- {module_path}::{name}\n\ CREATE OPERATOR {schema}{opname} (\n\ \tPROCEDURE={schema}\"{name}\",\n\ - \tLEFTARG={schema_prefix_left}{left_arg_sql}, /* {left_name} */\n\ - \tRIGHTARG={schema_prefix_right}{right_arg_sql}{maybe_comma} /* {right_name} */\n\ + \tLEFTARG={left_arg_sql}, /* {left_name} */\n\ + \tRIGHTARG={right_arg_sql}{maybe_comma} /* {right_name} */\n\ {optionals}\ );\ ", opname = op.opname.unwrap(), left_name = left_arg.used_ty.full_path, right_name = right_arg.used_ty.full_path, - schema_prefix_left = left_arg_schema_prefix, - schema_prefix_right = right_arg_schema_prefix, maybe_comma = if !optionals.is_empty() { "," } else { "" }, optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" @@ -489,47 +545,23 @@ impl ToSql for PgExternEntity<'_> { ext_sql += &operator_sql }; if let Some(cast) = &self.cast { - let target_fn_arg = &self.fn_return; - let target_ty = match target_fn_arg { + let target_ty = match &self.fn_return { PgExternReturnEntity::Type { ty } => ty, other => { return Err(eyre!("Casts must return a plain type, got: {other:?}")); } }; - let target_arg_schema_prefix = - context.schema_prefix_for_used_type(&self_index, "cast target type", target_ty)?; - let target_arg_sql = match &target_ty.metadata.return_sql { - Ok(Returns::One(SqlMapping::As(sql))) => sql.clone(), - Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => { - sql_type(mapping, target_ty.composite_type)? - } - Ok(Returns::One(SqlMapping::Skip)) => { - return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")); - } - Err(err) => return Err((*err).into()), - Ok(other) => { - return Err(eyre!("Casts must return a plain SQL type, got: {other:?}")); - } - }; + let target_arg_sql = render_function_return_type(context, self_index, self)?; let source_arg = self .fn_args .first() .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?; - let source_arg_schema_prefix = context.schema_prefix_for_used_type( - &self_index, + let source_arg_sql = render_used_type_sql( + context, + self_index, "cast source type", &source_arg.used_ty, )?; - let source_arg_sql = match source_arg.used_ty.metadata.argument_sql { - Ok(SqlMapping::As(ref sql)) => sql.clone(), - Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => { - sql_type(mapping, source_arg.used_ty.composite_type)? - } - Ok(SqlMapping::Skip) => { - return Err(eyre!("Found an skipped SQL type in a cast, this is not valid")); - } - Err(err) => return Err(err.into()), - }; let optional = match cast { PgCastEntity::Default => String::from(""), PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"), @@ -541,9 +573,9 @@ impl ToSql for PgExternEntity<'_> { -- {file}:{line}\n\ -- {module_path}::{name}\n\ CREATE CAST (\n\ - \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\ + \t{source_arg_sql} /* {source_name} */\n\ \tAS\n\ - \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\ + \t{target_arg_sql} /* {target_name} */\n\ )\n\ WITH FUNCTION {function_name}{optional};\ ", @@ -551,9 +583,7 @@ impl ToSql for PgExternEntity<'_> { line = self.line, name = self.name, module_path = self.module_path, - schema_prefix_source = source_arg_schema_prefix, source_name = source_arg.used_ty.full_path, - schema_prefix_target = target_arg_schema_prefix, target_name = target_ty.full_path, function_name = self.name, ); diff --git a/pgrx-sql-entity-graph/src/pgrx_sql.rs b/pgrx-sql-entity-graph/src/pgrx_sql.rs index c031a6cba..b1f4f6dc0 100644 --- a/pgrx-sql-entity-graph/src/pgrx_sql.rs +++ b/pgrx-sql-entity-graph/src/pgrx_sql.rs @@ -21,7 +21,8 @@ use petgraph::dot::Dot; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; use petgraph::visit::EdgeRef; -use std::collections::{BTreeMap, HashMap}; +use petgraph::Direction; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::fmt::Debug; use std::path::Path; @@ -573,6 +574,425 @@ impl<'a> PgrxSql<'a> { pub fn find_matching_fn(&self, name: &str) -> Option<&PgExternEntity<'a>> { self.externs.keys().find(|key| key.full_path.ends_with(name)) } + + /// Resolve a single user-supplied item name to one graph node. + /// + /// A match is any entity whose SQL-visible name, Rust path, or operator + /// symbol equals `name` exactly. A `::`-bearing argument is treated as a + /// Rust path (matched only against `full_path`). Ambiguous hits are a + /// hard error. + pub fn resolve_item(&self, name: &str) -> eyre::Result { + let by_path = name.contains("::"); + let mut matches: Vec<(NodeIndex, String)> = Vec::new(); + + for (entity, &idx) in &self.externs { + let fn_hit = if by_path { + entity.full_path == name + } else { + entity.name == name || entity.unaliased_name == name + }; + if fn_hit { + matches.push((idx, format!("function `{}`", entity.full_path))); + } + if !by_path + && let Some(op) = &entity.operator + && op.opname == Some(name) + && !matches.iter().any(|(existing, _)| *existing == idx) + { + matches.push(( + idx, + format!("operator `{}` on `{}`", name, entity.full_path), + )); + } + } + + for (entity, &idx) in &self.types { + let hit = + if by_path { entity.full_path == name } else { entity.name == name }; + if hit { + matches.push((idx, format!("type `{}`", entity.full_path))); + } + } + + for (entity, &idx) in &self.enums { + let hit = + if by_path { entity.full_path == name } else { entity.name == name }; + if hit { + matches.push((idx, format!("enum `{}`", entity.full_path))); + } + } + + for (entity, &idx) in &self.aggregates { + let hit = + if by_path { entity.full_path == name } else { entity.name == name }; + if hit { + matches.push((idx, format!("aggregate `{}`", entity.full_path))); + } + } + + for (entity, &idx) in &self.triggers { + let hit = if by_path { + entity.full_path == name + } else { + entity.function_name == name + }; + if hit { + matches.push((idx, format!("trigger `{}`", entity.full_path))); + } + } + + for (entity, &idx) in &self.extension_sqls { + if !by_path && entity.name == name { + matches.push((idx, format!("extension_sql `{}`", entity.name))); + continue; + } + for declared in &entity.creates { + let declared_name = match declared { + SqlDeclaredEntity::Type(data) | SqlDeclaredEntity::Enum(data) => { + data.name.as_str() + } + SqlDeclaredEntity::Function(data) => data.name.as_str(), + }; + if declared_name == name { + matches.push(( + idx, + format!( + "extension_sql `{}` (declares `{declared_name}`)", + entity.name + ), + )); + break; + } + } + } + + for (entity, &idx) in &self.schemas { + if !by_path && entity.name == name { + matches.push((idx, format!("schema `{}`", entity.name))); + } + } + + match matches.len() { + 0 => Err(eyre!("no SQL entity matches `{name}`")), + 1 => Ok(matches.remove(0).0), + _ => { + let labels = + matches.iter().map(|(_, l)| l.as_str()).collect::>().join(", "); + Err(eyre!( + "`{name}` is ambiguous; matched: {labels}. Disambiguate with a `::`-qualified Rust path." + )) + } + } + } + + /// Emit SQL for the given item names plus all transitive dependencies, in + /// dependency order, and substitute `'MODULE_PATHNAME'` with + /// `'$libdir/'` so the output can be replayed directly into a + /// database. + /// + /// When `extension_name` is `Some(name)`, the emitted slice is wrapped in + /// `BEGIN;`/`COMMIT;` and each created object is followed by an + /// `ALTER EXTENSION "" ADD …` clause so that piping the output into + /// a database where the extension is already installed attaches the new + /// objects to the extension. When `None`, the pre-feature behavior is + /// used (no transaction wrapping, no ADD clauses). + /// + /// Warnings (e.g. for `extension_sql!()` blocks without `creates = [...]`) + /// are written to stderr. Use `emit_slice_with_warnings` directly if you + /// need to capture them. + pub fn to_sql_for_items( + &self, + item_names: &[String], + lib_name: &str, + extension_name: Option<&str>, + ) -> eyre::Result { + self.emit_slice_with_warnings(item_names, lib_name, extension_name, |msg| { + eprintln!("{msg}"); + }) + } + + /// Core of [`Self::to_sql_for_items`]. Takes a warning sink so tests + /// (and future non-stderr callers) can observe the diagnostics that + /// would otherwise go to stderr. + pub(crate) fn emit_slice_with_warnings( + &self, + item_names: &[String], + lib_name: &str, + extension_name: Option<&str>, + warn: W, + ) -> eyre::Result { + let mut targets = Vec::with_capacity(item_names.len()); + for name in item_names { + targets.push(self.resolve_item(name)?); + } + self.emit_slice_from_nodes(&targets, lib_name, extension_name, warn) + } + + /// Same as [`Self::emit_slice_with_warnings`] but takes already-resolved + /// node indices. Used by tests that need to target entities whose + /// resolution is ambiguous or not supported by [`Self::resolve_item`] + /// (e.g. `Ord` and `Hash` derives). + pub(crate) fn emit_slice_from_nodes( + &self, + targets: &[NodeIndex], + lib_name: &str, + extension_name: Option<&str>, + mut warn: W, + ) -> eyre::Result { + let keep = self.collect_transitive_deps(targets); + + let mut body = String::new(); + for nodes in petgraph::algo::tarjan_scc(&self.graph).iter().rev() { + let ordered = self.connected_component_emit_order(nodes); + let mut block = Vec::new(); + + for node in ordered { + if !keep.contains(&node) { + continue; + } + let ent = &self.graph[node]; + + // The ExtensionRoot's CREATE-phase output is a trivial comment + // block that reads "auto generated by pgrx". Inside a slice + // aimed at an already-installed extension it would be strange + // and confusing, so skip it. + if matches!(ent, SqlGraphEntity::ExtensionRoot(_)) { + continue; + } + + let create_sql = ent.to_sql(self)?; + let create_sql = create_sql.trim(); + + let mut piece = String::new(); + if !create_sql.is_empty() { + piece.push_str(create_sql); + piece.push('\n'); + } + + if let Some(ext) = extension_name { + match self.render_alter_extension_for_node(node, ext)? { + Some(alter_sql) => { + piece.push_str(&alter_sql); + if !alter_sql.ends_with('\n') { + piece.push('\n'); + } + } + None => { + if let SqlGraphEntity::CustomSql(c) = ent + && c.creates.is_empty() + { + warn(format!( + "warning: extension_sql block at {}:{} does not declare `creates = [...]`; its objects won't be attached to the extension automatically", + c.file, c.line, + )); + } + } + } + } + + if !piece.is_empty() { + block.push(piece); + } + } + + if !block.is_empty() { + body.push_str("/* */\n"); + body.push_str(&block.join("\n")); + body.push_str("/* */\n\n"); + } + } + + let replacement = format!("'$libdir/{lib_name}'"); + let body = body.replace("'MODULE_PATHNAME'", &replacement); + + Ok(match extension_name { + Some(_) => format!("BEGIN;\n\n{body}\nCOMMIT;\n"), + None => body, + }) + } + + /// Produce the `ALTER EXTENSION "" ADD …;` clauses for `node`, or + /// `Ok(None)` when the node is not an extension-attachable object + /// (builtin type, extension root, or a free-form `extension_sql!()` + /// block that didn't declare `creates = [...]`). + fn render_alter_extension_for_node( + &self, + node: NodeIndex, + extension_name: &str, + ) -> eyre::Result> { + let ent = &self.graph[node]; + let ext = extension_name; + + match ent { + SqlGraphEntity::Function(f) => { + let schema = f + .schema + .map(|s| format!("{s}.")) + .unwrap_or_else(|| self.schema_prefix_for(&node)); + let argtypes = crate::pg_extern::entity::render_function_argtypes(self, node, f)?; + let mut out = + format!("ALTER EXTENSION \"{ext}\" ADD FUNCTION {schema}\"{name}\"({argtypes});", + name = f.name); + + if let Some(op) = &f.operator + && let Some(opname) = op.opname + { + let left = f.fn_args.first().ok_or_else(|| { + eyre!("operator `{}` missing left argument", f.name) + })?; + let right = f.fn_args.get(1).ok_or_else(|| { + eyre!("operator `{}` missing right argument", f.name) + })?; + let left_sql = crate::pg_extern::entity::render_used_type_sql( + self, + node, + "operator left argument", + &left.used_ty, + )?; + let right_sql = crate::pg_extern::entity::render_used_type_sql( + self, + node, + "operator right argument", + &right.used_ty, + )?; + out.push('\n'); + out.push_str(&format!( + "ALTER EXTENSION \"{ext}\" ADD OPERATOR {schema}{opname}({left_sql}, {right_sql});" + )); + } + + if f.cast.is_some() { + let source = f.fn_args.first().ok_or_else(|| { + eyre!("cast `{}` missing source argument", f.name) + })?; + let source_sql = crate::pg_extern::entity::render_used_type_sql( + self, + node, + "cast source type", + &source.used_ty, + )?; + let target_sql = + crate::pg_extern::entity::render_function_return_type(self, node, f)?; + out.push('\n'); + out.push_str(&format!( + "ALTER EXTENSION \"{ext}\" ADD CAST ({source_sql} AS {target_sql});" + )); + } + + Ok(Some(out)) + } + SqlGraphEntity::Type(t) => { + let schema = self.schema_prefix_for(&node); + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD TYPE {schema}{name};", + name = t.name + ))) + } + SqlGraphEntity::Enum(e) => { + let schema = self.schema_prefix_for(&node); + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD TYPE {schema}{name};", + name = e.name + ))) + } + SqlGraphEntity::Aggregate(a) => { + let schema = self.schema_prefix_for(&node); + let argtypes = + crate::aggregate::entity::render_aggregate_argtypes(self, node, a)?; + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD AGGREGATE {schema}\"{name}\"{argtypes};", + name = a.name + ))) + } + SqlGraphEntity::Trigger(t) => { + let schema = self.schema_prefix_for(&node); + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD FUNCTION {schema}\"{name}\"();", + name = t.function_name + ))) + } + SqlGraphEntity::Ord(o) => { + // Unqualified names: matches `PostgresOrdEntity::to_sql`, which + // also emits `{name}_btree_ops` without a schema prefix. + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD OPERATOR FAMILY {name}_btree_ops USING btree;\n\ + ALTER EXTENSION \"{ext}\" ADD OPERATOR CLASS {name}_btree_ops USING btree;", + name = o.name + ))) + } + SqlGraphEntity::Hash(h) => { + // Same unqualified-name rationale as Ord. + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD OPERATOR FAMILY {name}_hash_ops USING hash;\n\ + ALTER EXTENSION \"{ext}\" ADD OPERATOR CLASS {name}_hash_ops USING hash;", + name = h.name + ))) + } + SqlGraphEntity::Schema(s) => { + if matches!(s.name, "public" | "pg_catalog") { + return Ok(None); + } + Ok(Some(format!( + "ALTER EXTENSION \"{ext}\" ADD SCHEMA {name};", + name = s.name + ))) + } + SqlGraphEntity::CustomSql(c) => { + if c.creates.is_empty() { + return Ok(None); + } + let mut out = String::new(); + for (idx, declared) in c.creates.iter().enumerate() { + if idx > 0 { + out.push('\n'); + } + match declared { + SqlDeclaredEntity::Type(data) => { + out.push_str(&format!( + "ALTER EXTENSION \"{ext}\" ADD TYPE {};", + data.sql + )); + } + SqlDeclaredEntity::Enum(data) => { + out.push_str(&format!( + "ALTER EXTENSION \"{ext}\" ADD TYPE {};", + data.sql + )); + } + SqlDeclaredEntity::Function(data) => { + out.push_str(&format!( + "ALTER EXTENSION \"{ext}\" ADD FUNCTION {};", + data.sql + )); + } + } + } + Ok(Some(out)) + } + SqlGraphEntity::BuiltinType(_) | SqlGraphEntity::ExtensionRoot(_) => Ok(None), + } + } + + /// Collect every node reachable from `targets` by walking edges backward + /// (i.e. every dependency that must exist before the targets can be + /// created). The returned set always contains the targets themselves. + fn collect_transitive_deps(&self, targets: &[NodeIndex]) -> HashSet { + let mut visited = HashSet::new(); + let mut queue = VecDeque::new(); + for &t in targets { + if visited.insert(t) { + queue.push_back(t); + } + } + while let Some(node) = queue.pop_front() { + for predecessor in self.graph.neighbors_directed(node, Direction::Incoming) { + if visited.insert(predecessor) { + queue.push_back(predecessor); + } + } + } + visited + } } fn build_base_edges<'a>( @@ -1825,10 +2245,18 @@ mod tests { use super::*; use crate::UsedTypeEntity; use crate::aggregate::entity::{AggregateTypeEntity, PgAggregateEntity}; - use crate::extension_sql::entity::{ExtensionSqlEntity, SqlDeclaredTypeEntityData}; + use crate::extension_sql::entity::{ + ExtensionSqlEntity, SqlDeclaredEntity, SqlDeclaredTypeEntityData, + }; use crate::extern_args::ExternArgs; use crate::metadata::{FunctionMetadataTypeEntity, Returns, SqlArrayMapping, SqlMapping}; - use crate::pg_extern::entity::{PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity}; + use crate::pg_extern::entity::{ + PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity, PgOperatorEntity, + }; + use crate::pg_trigger::entity::PgTriggerEntity; + use crate::postgres_enum::entity::PostgresEnumEntity; + use crate::postgres_hash::entity::PostgresHashEntity; + use crate::postgres_ord::entity::PostgresOrdEntity; use crate::postgres_type::entity::PostgresTypeEntity; use crate::schema::entity::SchemaEntity; use crate::to_sql::entity::ToSqlConfigEntity; @@ -2465,4 +2893,654 @@ mod tests { assert!(error.to_string().contains("MSTYPE")); assert!(error.to_string().contains("tests::BadMovingState")); } + + #[test] + fn to_sql_for_items_emits_only_targets_and_deps_with_lib_substitution() { + let hexint = extension_owned_type("tests::HexInt", "tests::HexInt", "hexint"); + let declared = declared_type_sql( + "tests", + "tests::concrete_type", + "concrete_type", + "tests::HexInt", + "tests::HexInt", + "hexint", + ); + let target = + function_entity("emit_me", vec![], PgExternReturnEntity::Type { ty: hexint.clone() }); + let unused = function_entity( + "leave_me_out", + vec![], + PgExternReturnEntity::Type { + ty: external_type("alloc::string::String", "alloc::string::String", "text"), + }, + ); + + let pgrx_sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::CustomSql(declared), + SqlGraphEntity::Function(target), + SqlGraphEntity::Function(unused), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let sliced = pgrx_sql + .to_sql_for_items(&["emit_me".into()], "myext", None) + .expect("slice emission should succeed"); + + assert!(sliced.contains("emit_me"), "target function missing:\n{sliced}"); + assert!(sliced.contains("CREATE TYPE custom_type;"), "transitive dep missing:\n{sliced}"); + assert!(!sliced.contains("leave_me_out"), "unrelated function leaked:\n{sliced}"); + assert!( + sliced.contains("'$libdir/myext'"), + "MODULE_PATHNAME should be substituted:\n{sliced}" + ); + assert!(!sliced.contains("'MODULE_PATHNAME'"), "raw placeholder remained:\n{sliced}"); + } + + #[test] + fn resolve_item_rejects_ambiguous_name_without_path() { + let dup_a = function_entity("dup_fn", vec![], PgExternReturnEntity::None); + let mut dup_b = function_entity("dup_fn", vec![], PgExternReturnEntity::None); + dup_b.module_path = "tests::other"; + dup_b.full_path = "tests::other::dup_fn"; + + let pgrx_sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Function(dup_a), + SqlGraphEntity::Function(dup_b), + ] + .into_iter(), + "test".into(), + false, + ) + .unwrap(); + + let err = pgrx_sql.resolve_item("dup_fn").expect_err("ambiguous name should fail"); + let msg = err.to_string(); + assert!(msg.contains("ambiguous"), "expected ambiguity error, got: {msg}"); + assert!(msg.contains("tests::dup_fn"), "got: {msg}"); + assert!(msg.contains("tests::other::dup_fn"), "got: {msg}"); + + let unique = pgrx_sql + .resolve_item("tests::other::dup_fn") + .expect("qualified path should resolve"); + assert_eq!(pgrx_sql.graph[unique].rust_identifier(), "tests::other::dup_fn"); + } + + fn slice_with_warnings( + sql: &PgrxSql, + items: &[String], + lib_name: &str, + ext: Option<&str>, + ) -> (String, Vec) { + let mut warnings: Vec = Vec::new(); + let out = sql + .emit_slice_with_warnings(items, lib_name, ext, |msg| warnings.push(msg)) + .expect("slice emission should succeed"); + (out, warnings) + } + + fn slice_by_nodes( + sql: &PgrxSql, + targets: &[NodeIndex], + lib_name: &str, + ext: Option<&str>, + ) -> (String, Vec) { + let mut warnings: Vec = Vec::new(); + let out = sql + .emit_slice_from_nodes(targets, lib_name, ext, |msg| warnings.push(msg)) + .expect("slice emission should succeed"); + (out, warnings) + } + + fn trigger_entity(function_name: &'static str) -> PgTriggerEntity<'static> { + PgTriggerEntity { + function_name, + to_sql_config: to_sql_config(), + file: "test.rs", + line: 1, + module_path: "tests", + full_path: Box::leak(format!("tests::{function_name}").into_boxed_str()), + } + } + + fn ord_entity(name: &'static str) -> PostgresOrdEntity<'static> { + // full_path lives under `ord_for::` to avoid colliding with the + // underlying type's full_path (which is `tests::{name}`). The `name` + // field is what appears in CREATE OPERATOR FAMILY / CLASS. + PostgresOrdEntity { + name, + file: "test.rs", + line: 1, + full_path: Box::leak(format!("tests::ord_for::{name}").into_boxed_str()), + module_path: "tests::ord_for", + type_ident: Box::leak(format!("tests::{name}").into_boxed_str()), + to_sql_config: to_sql_config(), + } + } + + fn hash_entity(name: &'static str) -> PostgresHashEntity<'static> { + // Same disambiguation rationale as `ord_entity`. + PostgresHashEntity { + name, + file: "test.rs", + line: 1, + full_path: Box::leak(format!("tests::hash_for::{name}").into_boxed_str()), + module_path: "tests::hash_for", + type_ident: Box::leak(format!("tests::{name}").into_boxed_str()), + to_sql_config: to_sql_config(), + } + } + + fn enum_entity(name: &'static str) -> PostgresEnumEntity<'static> { + PostgresEnumEntity { + name, + file: "test.rs", + line: 1, + full_path: Box::leak(format!("tests::{name}").into_boxed_str()), + module_path: "tests", + type_ident: Box::leak(format!("tests::{name}").into_boxed_str()), + variants: vec!["red", "green", "blue"], + to_sql_config: to_sql_config(), + } + } + + #[test] + fn alter_extension_attaches_bare_function() { + let fun = function_entity("state_fn", vec![], PgExternReturnEntity::None); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(fun)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, warnings) = + slice_with_warnings(&sql, &["state_fn".into()], "myext", Some("myext")); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + assert!(out.starts_with("BEGIN;"), "missing BEGIN:\n{out}"); + assert!(out.trim_end().ends_with("COMMIT;"), "missing COMMIT:\n{out}"); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "state_fn"();"#), + "missing ADD FUNCTION:\n{out}" + ); + } + + #[test] + fn alter_extension_includes_argument_types() { + let arg_ty = external_type("alloc::string::String", "alloc::string::String", "text"); + let fun = function_entity( + "takes_text", + vec![PgExternArgumentEntity { pattern: "value", used_ty: arg_ty }], + PgExternReturnEntity::None, + ); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(fun)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings(&sql, &["takes_text".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "takes_text"(text);"#), + "missing argtype in ADD FUNCTION:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_operator_in_addition_to_function() { + let arg_ty = external_type("alloc::string::String", "alloc::string::String", "text"); + let mut fun = function_entity( + "eq_ignoring_case", + vec![ + PgExternArgumentEntity { pattern: "lhs", used_ty: arg_ty.clone() }, + PgExternArgumentEntity { pattern: "rhs", used_ty: arg_ty }, + ], + PgExternReturnEntity::Type { + ty: external_type("bool", "bool", "bool"), + }, + ); + fun.operator = Some(PgOperatorEntity { + opname: Some("==="), + commutator: None, + negator: None, + restrict: None, + join: None, + hashes: false, + merges: false, + }); + + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(fun)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = + slice_with_warnings(&sql, &["eq_ignoring_case".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "eq_ignoring_case"(text, text);"#), + "missing ADD FUNCTION:\n{out}" + ); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD OPERATOR ===(text, text);"#), + "missing ADD OPERATOR:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_type_and_its_io_functions() { + let ty = type_entity("MyType", "tests::MyType", "tests::MyType"); + let in_fn = function_entity( + "in_fn", + vec![PgExternArgumentEntity { + pattern: "input", + used_ty: external_type("&core::ffi::CStr", "&core::ffi::CStr", "cstring"), + }], + PgExternReturnEntity::Type { + ty: used_type("tests::MyType", "tests::MyType", "MyType", TypeOrigin::ThisExtension), + }, + ); + let out_fn = function_entity( + "out_fn", + vec![PgExternArgumentEntity { + pattern: "input", + used_ty: used_type( + "tests::MyType", + "tests::MyType", + "MyType", + TypeOrigin::ThisExtension, + ), + }], + PgExternReturnEntity::Type { + ty: external_type("alloc::ffi::CString", "alloc::ffi::CString", "cstring"), + }, + ); + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Type(ty), + SqlGraphEntity::Function(in_fn), + SqlGraphEntity::Function(out_fn), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = + slice_with_warnings(&sql, &["tests::MyType".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD TYPE MyType;"#), + "missing ADD TYPE:\n{out}" + ); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "in_fn"(cstring);"#), + "missing ADD FUNCTION for in_fn:\n{out}" + ); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "out_fn"(MyType);"#), + "missing ADD FUNCTION for out_fn:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_enum() { + let en = enum_entity("Color"); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Enum(en)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings(&sql, &["Color".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD TYPE Color;"#), + "missing ADD TYPE:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_aggregate_with_args() { + let stype = external_type("tests::State", "tests::State", "TEXT"); + let arg_ty = external_type("i32", "i32", "integer"); + let agg = aggregate_entity( + "sum_my", + vec![AggregateTypeEntity { used_ty: arg_ty, name: Some("value") }], + stype, + None, + ); + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Function(state_function()), + SqlGraphEntity::Aggregate(agg), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings(&sql, &["sum_my".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD AGGREGATE "sum_my"(integer);"#), + "missing ADD AGGREGATE:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_trigger() { + let trig = trigger_entity("my_trig"); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Trigger(trig)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings(&sql, &["my_trig".into()], "myext", Some("myext")); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD FUNCTION "my_trig"();"#), + "missing ADD FUNCTION (trigger):\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_ord_emits_family_and_class() { + // Build the underlying type + comparison functions so the Ord node + // has something to connect to. We don't assert on those; we only + // care that the Ord node's ALTER EXTENSION clauses come out right. + let mut ty = type_entity("Sortable", "tests::Sortable", "tests::Sortable"); + ty.in_fn_path = "sortable_in"; + ty.out_fn_path = "sortable_out"; + let text = external_type("alloc::string::String", "alloc::string::String", "text"); + let cstring = external_type("&core::ffi::CStr", "&core::ffi::CStr", "cstring"); + let in_fn = function_entity( + "sortable_in", + vec![PgExternArgumentEntity { pattern: "input", used_ty: cstring }], + PgExternReturnEntity::Type { + ty: used_type( + "tests::Sortable", + "tests::Sortable", + "Sortable", + TypeOrigin::ThisExtension, + ), + }, + ); + let out_fn = function_entity( + "sortable_out", + vec![PgExternArgumentEntity { + pattern: "input", + used_ty: used_type( + "tests::Sortable", + "tests::Sortable", + "Sortable", + TypeOrigin::ThisExtension, + ), + }], + PgExternReturnEntity::Type { ty: text }, + ); + let ord = ord_entity("Sortable"); + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Type(ty), + SqlGraphEntity::Function(in_fn), + SqlGraphEntity::Function(out_fn), + SqlGraphEntity::Ord(ord.clone()), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let ord_idx = sql.ords[&ord]; + let (out, _) = slice_by_nodes(&sql, &[ord_idx], "myext", Some("myext")); + assert!( + out.contains( + r#"ALTER EXTENSION "myext" ADD OPERATOR FAMILY Sortable_btree_ops USING btree;"# + ), + "missing ADD OPERATOR FAMILY:\n{out}" + ); + assert!( + out.contains( + r#"ALTER EXTENSION "myext" ADD OPERATOR CLASS Sortable_btree_ops USING btree;"# + ), + "missing ADD OPERATOR CLASS:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_hash_emits_family_and_class() { + let mut ty = type_entity("Hashable", "tests::Hashable", "tests::Hashable"); + ty.in_fn_path = "hashable_in"; + ty.out_fn_path = "hashable_out"; + let text = external_type("alloc::string::String", "alloc::string::String", "text"); + let cstring = external_type("&core::ffi::CStr", "&core::ffi::CStr", "cstring"); + let in_fn = function_entity( + "hashable_in", + vec![PgExternArgumentEntity { pattern: "input", used_ty: cstring }], + PgExternReturnEntity::Type { + ty: used_type( + "tests::Hashable", + "tests::Hashable", + "Hashable", + TypeOrigin::ThisExtension, + ), + }, + ); + let out_fn = function_entity( + "hashable_out", + vec![PgExternArgumentEntity { + pattern: "input", + used_ty: used_type( + "tests::Hashable", + "tests::Hashable", + "Hashable", + TypeOrigin::ThisExtension, + ), + }], + PgExternReturnEntity::Type { ty: text }, + ); + let hash = hash_entity("Hashable"); + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Type(ty), + SqlGraphEntity::Function(in_fn), + SqlGraphEntity::Function(out_fn), + SqlGraphEntity::Hash(hash.clone()), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let hash_idx = sql.hashes[&hash]; + let (out, _) = slice_by_nodes(&sql, &[hash_idx], "myext", Some("myext")); + assert!( + out.contains( + r#"ALTER EXTENSION "myext" ADD OPERATOR FAMILY Hashable_hash_ops USING hash;"# + ), + "missing ADD OPERATOR FAMILY:\n{out}" + ); + assert!( + out.contains( + r#"ALTER EXTENSION "myext" ADD OPERATOR CLASS Hashable_hash_ops USING hash;"# + ), + "missing ADD OPERATOR CLASS:\n{out}" + ); + } + + #[test] + fn alter_extension_attaches_schema_but_skips_public() { + let schema = schema_entity("tests::my_schema", "my_schema"); + let fun_arg = external_type("i32", "i32", "integer"); + let mut fun = function_entity( + "my_fn", + vec![PgExternArgumentEntity { pattern: "x", used_ty: fun_arg }], + PgExternReturnEntity::None, + ); + fun.module_path = "tests::my_schema"; + fun.full_path = "tests::my_schema::my_fn"; + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::Schema(schema), + SqlGraphEntity::Function(fun), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings( + &sql, + &["tests::my_schema::my_fn".into()], + "myext", + Some("myext"), + ); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD SCHEMA my_schema;"#), + "missing ADD SCHEMA:\n{out}" + ); + assert!( + !out.contains("ADD SCHEMA public"), + "should not emit ADD SCHEMA for public:\n{out}" + ); + } + + #[test] + fn alter_extension_custom_sql_with_creates_emits_add_type() { + let hexint = extension_owned_type("tests::HexInt", "tests::HexInt", "hexint"); + let declared = declared_type_sql( + "tests", + "tests::concrete_type", + "concrete_type", + "tests::HexInt", + "tests::HexInt", + "hexint", + ); + let target = + function_entity("uses_hexint", vec![], PgExternReturnEntity::Type { ty: hexint }); + + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::CustomSql(declared), + SqlGraphEntity::Function(target), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, warnings) = + slice_with_warnings(&sql, &["uses_hexint".into()], "myext", Some("myext")); + assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}"); + assert!( + out.contains(r#"ALTER EXTENSION "myext" ADD TYPE hexint;"#), + "missing ADD TYPE for declared type:\n{out}" + ); + } + + #[test] + fn alter_extension_custom_sql_without_creates_warns() { + let free_form = ExtensionSqlEntity { + module_path: "tests", + full_path: "tests::free_form_sql", + sql: "CREATE TABLE some_table(id INT);", + file: "somefile.rs", + line: 42, + name: "free_form_sql", + bootstrap: false, + finalize: false, + requires: vec![], + creates: vec![], + }; + + // Emit a function that transitively pulls the free-form block in + // through a `requires`. Simplest path: slice the free-form block + // directly by name. + let sql = PgrxSql::build( + vec![ + SqlGraphEntity::ExtensionRoot(control_file()), + SqlGraphEntity::CustomSql(free_form), + ] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, warnings) = + slice_with_warnings(&sql, &["free_form_sql".into()], "myext", Some("myext")); + assert!( + !out.contains(r#"ALTER EXTENSION "myext" ADD"#), + "free-form block should not emit ADD:\n{out}" + ); + assert_eq!(warnings.len(), 1, "expected one warning, got: {warnings:?}"); + assert!(warnings[0].contains("somefile.rs:42"), "warning missing file:line: {warnings:?}"); + assert!(warnings[0].contains("free-form") + || warnings[0].contains("creates"), "warning missing reason: {warnings:?}"); + } + + #[test] + fn no_alter_extension_mode_matches_pre_feature_output() { + let fun = function_entity("state_fn", vec![], PgExternReturnEntity::None); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(fun)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = slice_with_warnings(&sql, &["state_fn".into()], "myext", None); + assert!(!out.contains("ALTER EXTENSION"), "unexpected ALTER EXTENSION:\n{out}"); + assert!(!out.contains("BEGIN;"), "unexpected BEGIN:\n{out}"); + assert!(!out.contains("COMMIT;"), "unexpected COMMIT:\n{out}"); + } + + #[test] + fn alter_extension_substitutes_module_pathname() { + let fun = function_entity("state_fn", vec![], PgExternReturnEntity::None); + let sql = PgrxSql::build( + vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(fun)] + .into_iter(), + "myext".into(), + false, + ) + .unwrap(); + + let (out, _) = + slice_with_warnings(&sql, &["state_fn".into()], "myext", Some("myext")); + assert!(out.contains("'$libdir/myext'"), "missing libdir substitution:\n{out}"); + assert!(!out.contains("'MODULE_PATHNAME'"), "raw placeholder leaked:\n{out}"); + } } diff --git a/skills/cargo-pgrx/SKILL.md b/skills/cargo-pgrx/SKILL.md index fc64bd866..a07dce571 100644 --- a/skills/cargo-pgrx/SKILL.md +++ b/skills/cargo-pgrx/SKILL.md @@ -1,6 +1,6 @@ --- name: cargo-pgrx -description: "cargo pgrx CLI, test discipline (#[test] vs #[pg_test]), the pg_sys linkage boundary, and command routing for pgrx extension development. Use when writing tests, choosing between #[test] and #[pg_test], running builds, or invoking any cargo pgrx subcommand." +description: "Choose and run cargo-pgrx commands, pgrx tests, pg_test coverage, and pg_sys boundary checks." user-invocable: false ---