From 6efd1f0bc5302cbb7d3da6efc8ce6e955bdccb33 Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Sat, 21 Mar 2026 01:29:15 +0700 Subject: [PATCH 1/6] feat(db): add PostGIS support for Prisma --- Cargo.lock | 1 + psl/parser-database/Cargo.toml | 1 + psl/parser-database/src/attributes/default.rs | 6 + psl/parser-database/src/lib.rs | 4 +- psl/parser-database/src/types.rs | 182 +++++++++++++++--- .../capabilities_support.rs | 2 +- .../postgres_datamodel_connector.rs | 13 +- .../src/datamodel_connector/capabilities.rs | 3 +- .../validation_pipeline/validations.rs | 2 + .../validation_pipeline/validations/fields.rs | 85 +++++++- psl/schema-ast/src/ast.rs | 2 +- psl/schema-ast/src/ast/field.rs | 60 +++++- psl/schema-ast/src/parser/datamodel.pest | 24 ++- psl/schema-ast/src/parser/parse_types.rs | 55 +++++- quaint/.github/workflows/test.yml | 42 ++-- query-compiler/core-tests/Cargo.toml | 2 +- .../tests/geometry_find_many_graph_builds.rs | 53 +++++ query-compiler/core/src/constants.rs | 1 + .../core/src/query_document/parser.rs | 80 ++++---- query-compiler/dmmf/Cargo.toml | 2 +- .../src/ast_builders/datamodel_ast_builder.rs | 12 +- .../schema_ast_builder/type_renderer.rs | 29 +-- query-compiler/dmmf/src/tests/tests.rs | 51 +++++ .../sql-query-builder/src/convert.rs | 1 + .../src/model_extensions/scalar_field.rs | 2 + .../query-compiler/src/data_mapper.rs | 29 ++- .../tests/data/geometry-find-many.json | 11 ++ .../query-compiler/tests/data/schema.prisma | 5 + ...ries__queries@geometry-find-many.json.snap | 11 ++ .../query-structure/src/field/mod.rs | 14 +- .../query-structure/src/field/scalar.rs | 7 +- .../query-structure/src/prisma_value_ext.rs | 4 +- .../src/protocols/json/protocol_adapter.rs | 95 +++++++++ .../fields/data_input_mapper/update.rs | 3 + .../input_types/fields/field_filter_types.rs | 23 ++- .../schema/src/build/input_types/mod.rs | 1 + .../schema/src/build/output_types/field.rs | 1 + query-compiler/schema/src/output_types.rs | 4 + query-compiler/schema/src/query_schema.rs | 5 +- .../src/flavour/postgres/renderer.rs | 11 +- .../src/flavour/postgres/schema_differ.rs | 14 ++ .../src/flavour/sqlite/renderer.rs | 1 + .../introspection_pair/scalar_field.rs | 11 ++ .../src/sql_schema_calculator.rs | 41 +++- .../tests/postgres/mod.rs | 1 + .../tests/postgres/postgis_geometry.rs | 29 +++ .../tests/migrations/postgres.rs | 1 + .../migrations/postgres/postgis_geometry.rs | 88 +++++++++ schema-engine/sql-schema-describer/src/lib.rs | 7 + .../sql-schema-describer/src/mssql.rs | 1 + .../sql-schema-describer/src/mysql.rs | 4 +- .../sql-schema-describer/src/postgres.rs | 117 +++++++++++ .../src/postgres/default.rs | 1 + .../postgres/default/c_style_scalar_lists.rs | 1 + .../sql-schema-describer/src/sqlite.rs | 2 +- 55 files changed, 1107 insertions(+), 151 deletions(-) create mode 100644 query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs create mode 100644 query-compiler/query-compiler/tests/data/geometry-find-many.json create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap create mode 100644 schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs create mode 100644 schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs diff --git a/Cargo.lock b/Cargo.lock index 94b3da853616..780a9952b560 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3123,6 +3123,7 @@ dependencies = [ "itertools 0.14.0", "rustc-hash 2.1.1", "schema-ast", + "serde", ] [[package]] diff --git a/psl/parser-database/Cargo.toml b/psl/parser-database/Cargo.toml index b9a188cd3917..bf49885dfc25 100644 --- a/psl/parser-database/Cargo.toml +++ b/psl/parser-database/Cargo.toml @@ -11,3 +11,4 @@ enumflags2.workspace = true itertools.workspace = true either.workspace = true rustc-hash.workspace = true +serde = { workspace = true, features = ["derive"] } diff --git a/psl/parser-database/src/attributes/default.rs b/psl/parser-database/src/attributes/default.rs index adc29bc9b003..2f9795556e87 100644 --- a/psl/parser-database/src/attributes/default.rs +++ b/psl/parser-database/src/attributes/default.rs @@ -73,6 +73,9 @@ pub(super) fn visit_model_field_default( "Only @default(dbgenerated(\"...\")) can be used for Unsupported types.", ); } + ScalarFieldType::Geometry(_) => { + ctx.push_attribute_validation_error("Only @default(dbgenerated(\"...\")) can be used for Geometry types."); + } } } @@ -139,6 +142,9 @@ pub(super) fn visit_composite_field_default( ScalarFieldType::Unsupported(_) => { ctx.push_attribute_validation_error("Composite field of type `Unsupported` cannot have default values.") } + ScalarFieldType::Geometry(_) => { + ctx.push_attribute_validation_error("Composite field of type `Geometry` cannot have default values.") + } } } diff --git a/psl/parser-database/src/lib.rs b/psl/parser-database/src/lib.rs index c63ed20f9884..21f9f56c120a 100644 --- a/psl/parser-database/src/lib.rs +++ b/psl/parser-database/src/lib.rs @@ -56,8 +56,8 @@ pub use relations::{ManyToManyRelationId, ReferentialAction, RelationId}; use schema_ast::ast::{GeneratorConfig, SourceConfig}; pub use schema_ast::{SourceFile, ast}; pub use types::{ - IndexAlgorithm, IndexFieldPath, IndexType, OperatorClass, RelationFieldId, ScalarFieldId, ScalarFieldType, - ScalarType, SortOrder, WhereClause, WhereCondition, WhereValue, + GeometrySpec, GeometrySubtype, IndexAlgorithm, IndexFieldPath, IndexType, OperatorClass, PostgisSpatialKind, + RelationFieldId, ScalarFieldId, ScalarFieldType, ScalarType, SortOrder, WhereClause, WhereCondition, WhereValue, }; /// ParserDatabase is a container for a Schema AST, together with information diff --git a/psl/parser-database/src/types.rs b/psl/parser-database/src/types.rs index 55b6665a2c50..48a65b10b1d0 100644 --- a/psl/parser-database/src/types.rs +++ b/psl/parser-database/src/types.rs @@ -8,7 +8,11 @@ use either::Either; use enumflags2::bitflags; use rustc_hash::FxHashMap as HashMap; use schema_ast::ast::{self, EnumValueId, WithName}; -use std::{collections::BTreeMap, fmt}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::BTreeMap, + fmt::{self, Write as _}, +}; pub(super) fn resolve_types(ctx: &mut Context<'_>) { for ((file_id, top_id), top) in ctx.iter_tops() { @@ -184,6 +188,93 @@ impl UnsupportedType { } } +/// OGC / PostGIS geometry subtype for [`ScalarFieldType::Geometry`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum GeometrySubtype { + /// `POINT` subtype. + Point, + /// `LINESTRING` subtype. + LineString, + /// `POLYGON` subtype. + Polygon, + /// `MULTIPOINT` subtype. + MultiPoint, + /// `MULTILINESTRING` subtype. + MultiLineString, + /// `MULTIPOLYGON` subtype. + MultiPolygon, + /// `GEOMETRYCOLLECTION` subtype. + GeometryCollection, + /// Unrestricted `GEOMETRY` subtype. + Geometry, +} + +impl GeometrySubtype { + /// PSL spelling of the subtype (e.g. `Point`). + pub fn as_str(self) -> &'static str { + match self { + GeometrySubtype::Point => "Point", + GeometrySubtype::LineString => "LineString", + GeometrySubtype::Polygon => "Polygon", + GeometrySubtype::MultiPoint => "MultiPoint", + GeometrySubtype::MultiLineString => "MultiLineString", + GeometrySubtype::MultiPolygon => "MultiPolygon", + GeometrySubtype::GeometryCollection => "GeometryCollection", + GeometrySubtype::Geometry => "Geometry", + } + } +} + +impl From for GeometrySubtype { + fn from(s: ast::GeometrySubtype) -> Self { + match s { + ast::GeometrySubtype::Point => Self::Point, + ast::GeometrySubtype::LineString => Self::LineString, + ast::GeometrySubtype::Polygon => Self::Polygon, + ast::GeometrySubtype::MultiPoint => Self::MultiPoint, + ast::GeometrySubtype::MultiLineString => Self::MultiLineString, + ast::GeometrySubtype::MultiPolygon => Self::MultiPolygon, + ast::GeometrySubtype::GeometryCollection => Self::GeometryCollection, + ast::GeometrySubtype::Geometry => Self::Geometry, + } + } +} + +/// PostGIS base type for a [`GeometrySpec`] (`geometry` vs `geography` in PostgreSQL). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] +pub enum PostgisSpatialKind { + /// `geometry(...)` columns (planar). + #[default] + Geometry, + /// `geography(...)` columns (geodetic). + Geography, +} + +/// Parameters for a `Geometry(subtype, srid?)` scalar field type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct GeometrySpec { + /// Geometry subtype (OGC / PostGIS). + pub subtype: GeometrySubtype, + /// Spatial reference ID; `None` when omitted in the schema (distinct from SRID 0 in the database). + pub srid: Option, + /// Whether the physical column uses PostGIS `geometry` or `geography`. + #[serde(default)] + pub spatial: PostgisSpatialKind, +} + +impl GeometrySpec { + /// SQL column type for PostgreSQL / PostGIS (e.g. `geometry(Point,4326)` or `geography(Point,4326)`). + pub fn postgres_sql_type(&self) -> String { + let base = match self.spatial { + PostgisSpatialKind::Geometry => "geometry", + PostgisSpatialKind::Geography => "geography", + }; + let subtype = self.subtype.as_str(); + let srid = self.srid.unwrap_or(0); + format!("{base}({subtype},{srid})") + } +} + /// The type of a scalar field, parsed and categorized. #[derive(Debug, Clone, Copy, PartialEq)] pub enum ScalarFieldType { @@ -195,6 +286,8 @@ pub enum ScalarFieldType { Extension(ExtensionTypeId), /// A Prisma scalar type BuiltInScalar(ScalarType), + /// PostGIS-style `Geometry(Point, 4326)` scalar + Geometry(GeometrySpec), /// An `Unsupported("...")` type Unsupported(UnsupportedType), } @@ -277,6 +370,11 @@ impl ScalarFieldType { matches!(self, Self::BuiltInScalar(ScalarType::Decimal)) } + /// True if the field's type is `Geometry(...)`. + pub fn is_geometry(self) -> bool { + matches!(self, Self::Geometry(_)) + } + /// Display the field type as it would appear in the Prisma schema. pub fn display<'a>(&'a self, db: &'a ParserDatabase) -> impl fmt::Display + 'a { DisplayScalarFieldType { field_type: self, db } @@ -307,6 +405,13 @@ impl fmt::Display for DisplayScalarFieldType<'_> { .expect("extension type id to have a name"); write!(f, "{}", self.db.interner.get(*name).unwrap()) } + ScalarFieldType::Geometry(spec) => { + write!(f, "Geometry({}", spec.subtype.as_str())?; + if let Some(srid) = spec.srid { + write!(f, ", {srid}")?; + } + f.write_char(')') + } ScalarFieldType::Unsupported(ut) => { write!(f, "Unsupported(\"{}\")", self.db.interner.get(ut.name).unwrap()) } @@ -473,6 +578,10 @@ impl IndexAlgorithm { return true; } + if r#type.is_geometry() { + return matches!(self, IndexAlgorithm::BTree | IndexAlgorithm::Gist); + } + match self { IndexAlgorithm::BTree => true, IndexAlgorithm::Hash => true, @@ -836,41 +945,52 @@ fn visit_enum<'db>(enm: &'db ast::Enum, ctx: &mut Context<'db>) { /// Either a structured, supported type, or an Err(unsupported) if the type name /// does not match any we know of. fn field_type<'db>(field: &'db ast::Field, ctx: &mut Context<'db>) -> Result { - let supported = match &field.field_type { - ast::FieldType::Supported(ident) => &ident.name, + match &field.field_type { + ast::FieldType::Geometry { subtype, srid, .. } => { + Ok(FieldType::Scalar(ScalarFieldType::Geometry(GeometrySpec { + subtype: (*subtype).into(), + srid: *srid, + spatial: PostgisSpatialKind::Geometry, + }))) + } ast::FieldType::Unsupported(name, _) => { let unsupported = UnsupportedType::new(ctx.interner.intern(name)); - return Ok(FieldType::Scalar(ScalarFieldType::Unsupported(unsupported))); - } - }; - - if let Some(tpe) = ScalarType::try_from_str(supported, false) { - return Ok(FieldType::Scalar(ScalarFieldType::BuiltInScalar(tpe))); - } - - let supported_string_id = ctx.interner.intern(supported); - match ctx - .names - .tops - .get(&supported_string_id) - .map(|id| (id.0, id.1, &ctx.asts[*id])) - { - Some((file_id, ast::TopId::Model(model_id), ast::Top::Model(_))) => Ok(FieldType::Model((file_id, model_id))), - Some((file_id, ast::TopId::Enum(enum_id), ast::Top::Enum(_))) => { - Ok(FieldType::Scalar(ScalarFieldType::Enum((file_id, enum_id)))) + Ok(FieldType::Scalar(ScalarFieldType::Unsupported(unsupported))) } - Some((file_id, ast::TopId::CompositeType(ctid), ast::Top::CompositeType(_))) => { - Ok(FieldType::Scalar(ScalarFieldType::CompositeType((file_id, ctid)))) - } - Some((_, _, ast::Top::Generator(_))) | Some((_, _, ast::Top::Source(_))) => unreachable!(), - None => { - if let Some(type_id) = ctx.extension_types().get_by_prisma_name(supported) { - Ok(FieldType::Scalar(ScalarFieldType::Extension(type_id))) - } else { - Err(supported) + ast::FieldType::Supported(ident) => { + let supported = ident.name.as_str(); + + if let Some(tpe) = ScalarType::try_from_str(supported, false) { + return Ok(FieldType::Scalar(ScalarFieldType::BuiltInScalar(tpe))); + } + + let supported_string_id = ctx.interner.intern(supported); + match ctx + .names + .tops + .get(&supported_string_id) + .map(|id| (id.0, id.1, &ctx.asts[*id])) + { + Some((file_id, ast::TopId::Model(model_id), ast::Top::Model(_))) => { + Ok(FieldType::Model((file_id, model_id))) + } + Some((file_id, ast::TopId::Enum(enum_id), ast::Top::Enum(_))) => { + Ok(FieldType::Scalar(ScalarFieldType::Enum((file_id, enum_id)))) + } + Some((file_id, ast::TopId::CompositeType(ctid), ast::Top::CompositeType(_))) => { + Ok(FieldType::Scalar(ScalarFieldType::CompositeType((file_id, ctid)))) + } + Some((_, _, ast::Top::Generator(_))) | Some((_, _, ast::Top::Source(_))) => unreachable!(), + None => { + if let Some(type_id) = ctx.extension_types().get_by_prisma_name(supported) { + Ok(FieldType::Scalar(ScalarFieldType::Extension(type_id))) + } else { + Err(supported) + } + } + _ => unreachable!(), } } - _ => unreachable!(), } } diff --git a/psl/psl-core/src/builtin_connectors/capabilities_support.rs b/psl/psl-core/src/builtin_connectors/capabilities_support.rs index aef2d22cd076..5b666a8c7b8d 100644 --- a/psl/psl-core/src/builtin_connectors/capabilities_support.rs +++ b/psl/psl-core/src/builtin_connectors/capabilities_support.rs @@ -83,7 +83,7 @@ macro_rules! reachable_only_with_capability { #[inline(always)] #[allow(dead_code)] // not used if more than one connector is built const fn check_comptime_capability(capabilities: ConnectorCapabilities, cap: ConnectorCapability) -> bool { - (capabilities.bits_c() & (cap as u64)) > 0 + (capabilities.bits_c() & (cap as u128)) > 0 } #[inline(always)] diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index 6398275b336d..93c52ecd9221 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -3,7 +3,7 @@ mod native_types; mod validations; pub use native_types::{KnownPostgresType, PostgresType}; -use parser_database::{ExtensionTypes, ScalarFieldType}; +use parser_database::{ExtensionTypes, GeometrySpec, ScalarFieldType}; use crate::{ Configuration, Datasource, DatasourceConnectorData, PreviewFeature, ValidatedSchema, @@ -74,7 +74,8 @@ pub const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Conne SupportsFiltersOnRelationsWithoutJoins | LateralJoin | SupportsDefaultInInsert | - PartialIndex + PartialIndex | + PostgisGeometry }); pub struct PostgresDatamodelConnector; @@ -82,6 +83,10 @@ pub struct PostgresDatamodelConnector; const DATE_TIME_DEFAULT: KnownPostgresType = KnownPostgresType::Timestamp(Some(3)); const BYTES_DEFAULT: KnownPostgresType = KnownPostgresType::ByteA; +fn geometry_sql_column_type(spec: &GeometrySpec) -> String { + spec.postgres_sql_type() +} + const SCALAR_TYPE_DEFAULTS: &[(ScalarType, KnownPostgresType)] = &[ (ScalarType::Int, KnownPostgresType::Integer), (ScalarType::BigInt, KnownPostgresType::BigInt), @@ -374,6 +379,10 @@ impl Connector for PostgresDatamodelConnector { let native_type = PostgresType::Unknown(name.to_owned(), modifiers.to_vec()); return Some(NativeTypeInstance::new::(native_type)); } + ScalarFieldType::Geometry(spec) => { + let native_type = PostgresType::Unknown(geometry_sql_column_type(spec), Vec::new()); + return Some(NativeTypeInstance::new::(native_type)); + } ScalarFieldType::CompositeType(_) | ScalarFieldType::Enum(_) | ScalarFieldType::Unsupported(_) => { return None; } diff --git a/psl/psl-core/src/datamodel_connector/capabilities.rs b/psl/psl-core/src/datamodel_connector/capabilities.rs index 8cc67a244726..ef52ea53adf0 100644 --- a/psl/psl-core/src/datamodel_connector/capabilities.rs +++ b/psl/psl-core/src/datamodel_connector/capabilities.rs @@ -6,7 +6,7 @@ macro_rules! capabilities { ($( $variant:ident $(,)? ),*) => { #[derive(Debug, Clone, Copy, PartialEq)] #[enumflags2::bitflags] - #[repr(u64)] + #[repr(u128)] pub enum ConnectorCapability { $( $variant, @@ -93,6 +93,7 @@ capabilities!( AdvancedJsonNullability, // Connector distinguishes between their null type and JSON null. UndefinedType, // Connector distinguishes `null` and `undefined` DecimalType, // Connector supports Prisma Decimal type. + PostgisGeometry, // Connector supports first-class `Geometry(...)` (PostgreSQL / PostGIS). BackwardCompatibleQueryRaw, // Temporary SQLite specific capability. Should be removed once https://github.com/prisma/prisma/issues/12784 is fixed, OrderByNullsFirstLast, // Connector supports ORDER BY NULLS LAST/FIRST FilteredInlineChildNestedToOneDisconnect, // Connector supports a filtered nested disconnect on both sides of a to-one relation. diff --git a/psl/psl-core/src/validate/validation_pipeline/validations.rs b/psl/psl-core/src/validate/validation_pipeline/validations.rs index 9e19b13b2d8b..1bf625e54eb1 100644 --- a/psl/psl-core/src/validate/validation_pipeline/validations.rs +++ b/psl/psl-core/src/validate/validation_pipeline/validations.rs @@ -32,6 +32,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for field in composite_type.fields() { composite_types::validate_default_value(field, ctx); fields::validate_native_type_arguments(field, ctx); + fields::validate_geometry_on_composite_field(field, ctx); } } } @@ -92,6 +93,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for field in model.scalar_fields() { fields::validate_scalar_field_connector_specific(field, ctx); + fields::validate_geometry_field(field, ctx); fields::validate_client_name(field.into(), &names, ctx); fields::has_a_unique_default_constraint_name(field, &names, ctx); fields::validate_native_type_arguments(field, ctx); diff --git a/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs b/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs index 67b01495500b..a82f3fbc2273 100644 --- a/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs +++ b/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs @@ -10,9 +10,12 @@ use crate::datamodel_connector::{ConnectorCapability, NativeTypeConstructor, wal use crate::{diagnostics::DatamodelError, validate::validation_pipeline::context::Context}; use itertools::Itertools; use parser_database::{ - ScalarFieldType, ScalarType, + GeometrySpec, ScalarFieldType, ScalarType, ast::{self, WithSpan}, - walkers::{FieldWalker, PrimaryKeyWalker, ScalarFieldAttributeWalker, ScalarFieldWalker, TypedFieldWalker}, + walkers::{ + CompositeTypeFieldWalker, FieldWalker, PrimaryKeyWalker, ScalarFieldAttributeWalker, ScalarFieldWalker, + TypedFieldWalker, + }, }; pub(super) fn validate_client_name(field: FieldWalker<'_>, names: &Names<'_>, ctx: &mut Context<'_>) { @@ -328,6 +331,84 @@ pub(super) fn validate_scalar_field_connector_specific(field: ScalarFieldWalker< } } +fn validate_geometry_spec_constraints( + spec: GeometrySpec, + ctx: &mut Context<'_>, + container: &str, + container_name: &str, + field_name: &str, + field_span: ast::Span, + type_span: ast::Span, +) { + if !ctx.has_capability(ConnectorCapability::PostgisGeometry) { + let msg = format!( + "Field `{field_name}` in {container} `{container_name}` uses type Geometry, which is only supported on PostgreSQL with PostGIS.", + ); + if container == "composite type" { + ctx.push_error(DatamodelError::new_composite_type_validation_error( + &msg, + container_name, + field_span, + )); + } else { + ctx.push_error(DatamodelError::new_field_validation_error( + &msg, + container, + container_name, + field_name, + field_span, + )); + } + } + + if let Some(srid) = spec.srid + && (srid < 0 || srid > 999_999) + { + ctx.push_error(DatamodelError::new_validation_error( + &format!("Invalid SRID {srid}. Must be between 0 and 999999 when specified."), + type_span, + )); + } +} + +pub(super) fn validate_geometry_field(field: ScalarFieldWalker<'_>, ctx: &mut Context<'_>) { + let ScalarFieldType::Geometry(spec) = field.scalar_field_type() else { + return; + }; + + let container = if field.model().ast_model().is_view() { + "view" + } else { + "model" + }; + + validate_geometry_spec_constraints( + spec, + ctx, + container, + field.model().name(), + field.name(), + field.ast_field().span(), + field.ast_field().field_type.span(), + ); +} + +pub(super) fn validate_geometry_on_composite_field(field: CompositeTypeFieldWalker<'_>, ctx: &mut Context<'_>) { + let ScalarFieldType::Geometry(spec) = field.r#type() else { + return; + }; + + validate_geometry_spec_constraints( + spec, + ctx, + "composite type", + field.composite_type().name(), + field.name(), + field.ast_field().span(), + field.ast_field().field_type.span(), + ); +} + pub(super) fn validate_unsupported_field_type(field: ScalarFieldWalker<'_>, ctx: &mut Context<'_>) { use regex::Regex; diff --git a/psl/schema-ast/src/ast.rs b/psl/schema-ast/src/ast.rs index 8158e30f73a2..f762a34d9158 100644 --- a/psl/schema-ast/src/ast.rs +++ b/psl/schema-ast/src/ast.rs @@ -25,7 +25,7 @@ pub use config::ConfigBlockProperty; pub use diagnostics::Span; pub use r#enum::{Enum, EnumValue, EnumValueId}; pub use expression::{Expression, ObjectMember}; -pub use field::{Field, FieldArity, FieldType}; +pub use field::{Field, FieldArity, FieldType, GeometrySubtype}; pub use find_at_position::*; pub use generator_config::GeneratorConfig; pub use identifier::Identifier; diff --git a/psl/schema-ast/src/ast/field.rs b/psl/schema-ast/src/ast/field.rs index 394381a2f3b1..c3ba2f131fd4 100644 --- a/psl/schema-ast/src/ast/field.rs +++ b/psl/schema-ast/src/ast/field.rs @@ -1,9 +1,38 @@ -use std::fmt::Display; +use std::fmt::{Display, Write}; use super::{ Attribute, Comment, Identifier, Span, WithAttributes, WithDocumentation, WithIdentifier, WithName, WithSpan, }; +/// OGC / PostGIS geometry subtype written in `Geometry(...)` field types. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GeometrySubtype { + Point, + LineString, + Polygon, + MultiPoint, + MultiLineString, + MultiPolygon, + GeometryCollection, + Geometry, +} + +impl GeometrySubtype { + /// PSL spelling of the subtype (e.g. `Point`). + pub fn as_str(self) -> &'static str { + match self { + GeometrySubtype::Point => "Point", + GeometrySubtype::LineString => "LineString", + GeometrySubtype::Polygon => "Polygon", + GeometrySubtype::MultiPoint => "MultiPoint", + GeometrySubtype::MultiLineString => "MultiLineString", + GeometrySubtype::MultiPolygon => "MultiPolygon", + GeometrySubtype::GeometryCollection => "GeometryCollection", + GeometrySubtype::Geometry => "Geometry", + } + } +} + /// A field definition in a model or a composite type. #[derive(Debug, Clone)] pub struct Field { @@ -52,7 +81,7 @@ impl Display for Field { "" }; - write!(f, "{} {}{}", self.name(), self.field_type.name(), extension) + write!(f, "{} {}{}", self.name(), self.field_type, extension) } } @@ -149,6 +178,12 @@ impl FieldArity { #[derive(Debug, Clone, PartialEq)] pub enum FieldType { Supported(Identifier), + /// `Geometry(Point, 4326)` or `Geometry(LineString)` (SRID optional). + Geometry { + subtype: GeometrySubtype, + srid: Option, + span: Span, + }, /// Unsupported("...") Unsupported(String, Span), } @@ -157,6 +192,7 @@ impl FieldType { pub fn span(&self) -> Span { match self { FieldType::Supported(ident) => ident.span, + FieldType::Geometry { span, .. } => *span, FieldType::Unsupported(_, span) => *span, } } @@ -164,6 +200,7 @@ impl FieldType { pub fn name(&self) -> &str { match self { FieldType::Supported(supported) => &supported.name, + FieldType::Geometry { .. } => "Geometry", FieldType::Unsupported(name, _) => name, } } @@ -171,7 +208,24 @@ impl FieldType { pub fn as_unsupported(&self) -> Option<(&str, &Span)> { match self { FieldType::Unsupported(name, span) => Some((name, span)), - FieldType::Supported(_) => None, + FieldType::Supported(_) | FieldType::Geometry { .. } => None, + } + } +} + +impl Display for FieldType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FieldType::Supported(ident) => f.write_str(&ident.name), + FieldType::Geometry { subtype, srid, .. } => { + f.write_str("Geometry(")?; + f.write_str(subtype.as_str())?; + if let Some(srid) = srid { + write!(f, ", {srid}")?; + } + f.write_char(')') + } + FieldType::Unsupported(name, _) => write!(f, "Unsupported({})", crate::string_literal(name)), } } } diff --git a/psl/schema-ast/src/parser/datamodel.pest b/psl/schema-ast/src/parser/datamodel.pest index e9c91aebdc29..266105cbcf6d 100644 --- a/psl/schema-ast/src/parser/datamodel.pest +++ b/psl/schema-ast/src/parser/datamodel.pest @@ -55,7 +55,29 @@ model_contents = { field_type = { unsupported_optional_list_type | list_type | optional_type | legacy_required_type | legacy_list_type | base_type } unsupported_type = { "Unsupported(" ~ string_literal ~ ")" } -base_type = { unsupported_type | identifier } // Called base type to not conflict with type rust keyword + +geometry_subtype = { + "MultiPoint" + | "MultiLineString" + | "MultiPolygon" + | "GeometryCollection" + | "Point" + | "LineString" + | "Polygon" + | "Geometry" +} + +geometry_srid = @{ ASCII_DIGIT+ } + +geometry_type = { + "Geometry" + ~ "(" + ~ geometry_subtype + ~ ( "," ~ geometry_srid )? + ~ ")" +} + +base_type = { unsupported_type | geometry_type | identifier } // Called base type to not conflict with type rust keyword unsupported_optional_list_type = { base_type ~ "[]" ~ "?" } list_type = { base_type ~ "[]" } optional_type = { base_type ~ "?" } diff --git a/psl/schema-ast/src/parser/parse_types.rs b/psl/schema-ast/src/parser/parse_types.rs index d70056aab56b..5fe91559a829 100644 --- a/psl/schema-ast/src/parser/parse_types.rs +++ b/psl/schema-ast/src/parser/parse_types.rs @@ -12,12 +12,12 @@ pub fn parse_field_type( match current.as_rule() { Rule::optional_type => Ok(( FieldArity::Optional, - parse_base_type(current.into_inner().next().unwrap(), diagnostics, file_id), + parse_base_type(current.into_inner().next().unwrap(), diagnostics, file_id)?, )), - Rule::base_type => Ok((FieldArity::Required, parse_base_type(current, diagnostics, file_id))), + Rule::base_type => Ok((FieldArity::Required, parse_base_type(current, diagnostics, file_id)?)), Rule::list_type => Ok(( FieldArity::List, - parse_base_type(current.into_inner().next().unwrap(), diagnostics, file_id), + parse_base_type(current.into_inner().next().unwrap(), diagnostics, file_id)?, )), Rule::legacy_required_type => Err(DatamodelError::new_legacy_parser_error( "Fields are required by default, `!` is no longer required.", @@ -35,17 +35,58 @@ pub fn parse_field_type( } } -fn parse_base_type(pair: Pair<'_>, diagnostics: &mut Diagnostics, file_id: FileId) -> FieldType { +fn parse_base_type( + pair: Pair<'_>, + diagnostics: &mut Diagnostics, + file_id: FileId, +) -> Result { let current = pair.into_inner().next().unwrap(); match current.as_rule() { - Rule::identifier => FieldType::Supported(Identifier { + Rule::identifier => Ok(FieldType::Supported(Identifier { name: current.as_str().to_string(), span: Span::from((file_id, current.as_span())), - }), + })), Rule::unsupported_type => match parse_expression(current, diagnostics, file_id) { - Expression::StringValue(lit, span) => FieldType::Unsupported(lit, span), + Expression::StringValue(lit, span) => Ok(FieldType::Unsupported(lit, span)), _ => unreachable!("Encountered impossible type during parsing"), }, + Rule::geometry_type => parse_geometry_type(current, file_id), _ => unreachable!("Encountered impossible type during parsing: {:?}", current.tokens()), } } + +fn parse_geometry_type(pair: Pair<'_>, file_id: FileId) -> Result { + let span = Span::from((file_id, pair.as_span())); + let mut inner = pair.into_inner(); + let subtype_pair = inner.next().expect("geometry: subtype"); + debug_assert_eq!(subtype_pair.as_rule(), Rule::geometry_subtype); + let subtype = match subtype_pair.as_str() { + "Point" => crate::ast::GeometrySubtype::Point, + "LineString" => crate::ast::GeometrySubtype::LineString, + "Polygon" => crate::ast::GeometrySubtype::Polygon, + "MultiPoint" => crate::ast::GeometrySubtype::MultiPoint, + "MultiLineString" => crate::ast::GeometrySubtype::MultiLineString, + "MultiPolygon" => crate::ast::GeometrySubtype::MultiPolygon, + "GeometryCollection" => crate::ast::GeometrySubtype::GeometryCollection, + "Geometry" => crate::ast::GeometrySubtype::Geometry, + _ => unreachable!("geometry_subtype rule produced unexpected token"), + }; + + let srid = if let Some(srid_pair) = inner.next() { + debug_assert_eq!(srid_pair.as_rule(), Rule::geometry_srid); + let raw = srid_pair.as_str(); + match raw.parse::() { + Ok(v) => Some(v), + Err(_) => { + return Err(DatamodelError::new_validation_error( + "Invalid SRID: expected a valid 32-bit integer.", + (file_id, srid_pair.as_span()).into(), + )); + } + } + } else { + None + }; + + Ok(FieldType::Geometry { subtype, srid, span }) +} diff --git a/quaint/.github/workflows/test.yml b/quaint/.github/workflows/test.yml index dd9a1b8e96c4..cb8cf5a88cf2 100644 --- a/quaint/.github/workflows/test.yml +++ b/quaint/.github/workflows/test.yml @@ -8,14 +8,14 @@ jobs: clippy: runs-on: ubuntu-latest env: - RUSTFLAGS: "-Dwarnings" + RUSTFLAGS: '-Dwarnings' steps: - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: - components: clippy - override: true - toolchain: stable + components: clippy + override: true + toolchain: stable - name: Install dependencies run: sudo apt install -y openssl libkrb5-dev - uses: actions-rs/clippy-check@v1 @@ -44,24 +44,24 @@ jobs: fail-fast: false matrix: features: - - "--lib --features=all" - - "--lib --no-default-features --features=sqlite" - - "--lib --no-default-features --features=sqlite --features=pooled" - - "--lib --no-default-features --features=postgresql" - - "--lib --no-default-features --features=postgresql --features=pooled" - - "--lib --no-default-features --features=mysql" - - "--lib --no-default-features --features=mysql --features=pooled" - - "--lib --no-default-features --features=mssql" - - "--lib --no-default-features --features=mssql --features=pooled" - - "--doc --features=all" + - '--lib --features=all' + - '--lib --no-default-features --features=sqlite' + - '--lib --no-default-features --features=sqlite --features=pooled' + - '--lib --no-default-features --features=postgresql' + - '--lib --no-default-features --features=postgresql --features=pooled' + - '--lib --no-default-features --features=mysql' + - '--lib --no-default-features --features=mysql --features=pooled' + - '--lib --no-default-features --features=mssql' + - '--lib --no-default-features --features=mssql --features=pooled' + - '--doc --features=all' env: - TEST_MYSQL: "mysql://root:prisma@localhost:3306/prisma" - TEST_MYSQL8: "mysql://root:prisma@localhost:3307/prisma" - TEST_MYSQL_MARIADB: "mysql://root:prisma@localhost:3308/prisma" - TEST_PSQL: "postgres://postgres:prisma@localhost:5432/postgres" - TEST_MSSQL: "jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true" - TEST_CRDB: "postgresql://prisma@127.0.0.1:26259/postgres" - RUSTFLAGS: "-Dwarnings" + TEST_MYSQL: 'mysql://root:prisma@localhost:3306/prisma' + TEST_MYSQL8: 'mysql://root:prisma@localhost:3307/prisma' + TEST_MYSQL_MARIADB: 'mysql://root:prisma@localhost:3308/prisma' + TEST_PSQL: 'postgres://postgres:prisma@localhost:5432/postgres' + TEST_MSSQL: 'jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true' + TEST_CRDB: 'postgresql://prisma@127.0.0.1:26259/postgres' + RUSTFLAGS: '-Dwarnings' steps: - uses: actions/checkout@v4 diff --git a/query-compiler/core-tests/Cargo.toml b/query-compiler/core-tests/Cargo.toml index 35f223636e12..83dada64a1f1 100644 --- a/query-compiler/core-tests/Cargo.toml +++ b/query-compiler/core-tests/Cargo.toml @@ -12,7 +12,7 @@ user-facing-errors.workspace = true request-handlers = { workspace = true, features = ["all"] } query-core.workspace = true schema.workspace = true -psl.workspace = true +psl = { workspace = true, features = ["postgresql"] } serde_json.workspace = true [[bench]] diff --git a/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs b/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs new file mode 100644 index 000000000000..3aa8dcb3a896 --- /dev/null +++ b/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; + +use query_core::{QueryDocument, QueryGraphBuilder}; +use request_handlers::{JsonBody, JsonSingleQuery, RequestBody}; + +#[test] +fn geometry_find_many_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + generator client { + provider = "prisma-client" + previewFeatures = ["relationJoins"] + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": {}, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with geometry fields should compile to a query graph"); +} diff --git a/query-compiler/core/src/constants.rs b/query-compiler/core/src/constants.rs index 2ec2a7680060..1b51c7a6d284 100644 --- a/query-compiler/core/src/constants.rs +++ b/query-compiler/core/src/constants.rs @@ -8,6 +8,7 @@ pub mod custom_types { pub const BIGINT: &str = "BigInt"; pub const DECIMAL: &str = "Decimal"; pub const BYTES: &str = "Bytes"; + pub const GEOMETRY: &str = "Geometry"; pub const JSON: &str = "Json"; pub const ENUM: &str = "Enum"; pub const FIELD_REF: &str = "FieldRef"; diff --git a/query-compiler/core/src/query_document/parser.rs b/query-compiler/core/src/query_document/parser.rs index 2a7e147a4e1d..45dbb77aa709 100644 --- a/query-compiler/core/src/query_document/parser.rs +++ b/query-compiler/core/src/query_document/parser.rs @@ -313,7 +313,7 @@ impl QueryDocumentParser { ))), // Scalar handling (pv, InputType::Scalar(st)) => try_this!( - self.parse_scalar(&selection_path, &argument_path, pv, *st, &value, is_parameterizable) + self.parse_scalar(&selection_path, &argument_path, pv, st, &value, is_parameterizable) .map(ParsedInputValue::Single) ), @@ -407,48 +407,51 @@ impl QueryDocumentParser { selection_path: &Path, argument_path: &Path, value: PrismaValue, - scalar_type: ScalarType, + scalar_type: &ScalarType, argument_value: &ArgumentValue, is_parameterizable: bool, ) -> QueryParserResult { match (value, scalar_type) { // Identity matchers - (PrismaValue::String(s), ScalarType::String) => Ok(PrismaValue::String(s)), - (PrismaValue::Boolean(b), ScalarType::Boolean) => Ok(PrismaValue::Boolean(b)), - (PrismaValue::Json(json), ScalarType::Json) => Ok(PrismaValue::Json(json)), - (PrismaValue::Uuid(uuid), ScalarType::UUID) => Ok(PrismaValue::Uuid(uuid)), - (PrismaValue::Bytes(bytes), ScalarType::Bytes) => Ok(PrismaValue::Bytes(bytes)), - (PrismaValue::BigInt(b_int), ScalarType::BigInt) => Ok(PrismaValue::BigInt(b_int)), - (PrismaValue::DateTime(s), ScalarType::DateTime) => Ok(PrismaValue::DateTime(s)), - (PrismaValue::Null, ScalarType::Null) => Ok(PrismaValue::Null), + (PrismaValue::String(s), &ScalarType::String) => Ok(PrismaValue::String(s)), + (PrismaValue::Boolean(b), &ScalarType::Boolean) => Ok(PrismaValue::Boolean(b)), + (PrismaValue::Json(json), &ScalarType::Json) => Ok(PrismaValue::Json(json)), + (PrismaValue::Uuid(uuid), &ScalarType::UUID) => Ok(PrismaValue::Uuid(uuid)), + (PrismaValue::Bytes(bytes), &ScalarType::Bytes) => Ok(PrismaValue::Bytes(bytes)), + (pv @ PrismaValue::Bytes(_), &ScalarType::Geometry(_)) => Ok(pv), + (pv @ PrismaValue::String(_), &ScalarType::Geometry(_)) => Ok(pv), + (PrismaValue::Json(s), &ScalarType::Geometry(_)) => Ok(PrismaValue::Bytes(s.into_bytes())), + (PrismaValue::BigInt(b_int), &ScalarType::BigInt) => Ok(PrismaValue::BigInt(b_int)), + (PrismaValue::DateTime(s), &ScalarType::DateTime) => Ok(PrismaValue::DateTime(s)), + (PrismaValue::Null, &ScalarType::Null) => Ok(PrismaValue::Null), // String coercion matchers - (PrismaValue::String(s), ScalarType::JsonList) => { + (PrismaValue::String(s), &ScalarType::JsonList) => { self.parse_json_list_from_str(selection_path, argument_path, &s) } - (PrismaValue::String(s), ScalarType::Bytes) => self.parse_bytes(selection_path, argument_path, s), - (PrismaValue::String(s), ScalarType::Decimal) => self.parse_decimal(selection_path, argument_path, s), - (PrismaValue::String(s), ScalarType::BigInt) => self.parse_bigint(selection_path, argument_path, s), - (PrismaValue::String(s), ScalarType::UUID) => self + (PrismaValue::String(s), &ScalarType::Bytes) => self.parse_bytes(selection_path, argument_path, s), + (PrismaValue::String(s), &ScalarType::Decimal) => self.parse_decimal(selection_path, argument_path, s), + (PrismaValue::String(s), &ScalarType::BigInt) => self.parse_bigint(selection_path, argument_path, s), + (PrismaValue::String(s), &ScalarType::UUID) => self .parse_uuid(selection_path, argument_path, s.as_str()) .map(PrismaValue::Uuid), - (PrismaValue::String(s), ScalarType::Json) => Ok(PrismaValue::Json( + (PrismaValue::String(s), &ScalarType::Json) => Ok(PrismaValue::Json( self.parse_json(selection_path, argument_path, &s).map(|_| s)?, )), - (PrismaValue::String(s), ScalarType::DateTime) => self + (PrismaValue::String(s), &ScalarType::DateTime) => self .parse_datetime(selection_path, argument_path, s.as_str()) .map(PrismaValue::DateTime), // Int coercion matchers - (PrismaValue::Int(i), ScalarType::Int) => Ok(PrismaValue::Int(i)), - (PrismaValue::Int(i), ScalarType::Float) => Ok(PrismaValue::Float(BigDecimal::from(i))), - (PrismaValue::Int(i), ScalarType::Decimal) => Ok(PrismaValue::Float(BigDecimal::from(i))), - (PrismaValue::Int(i), ScalarType::BigInt) => Ok(PrismaValue::BigInt(i)), + (PrismaValue::Int(i), &ScalarType::Int) => Ok(PrismaValue::Int(i)), + (PrismaValue::Int(i), &ScalarType::Float) => Ok(PrismaValue::Float(BigDecimal::from(i))), + (PrismaValue::Int(i), &ScalarType::Decimal) => Ok(PrismaValue::Float(BigDecimal::from(i))), + (PrismaValue::Int(i), &ScalarType::BigInt) => Ok(PrismaValue::BigInt(i)), // Float coercion matchers - (PrismaValue::Float(f), ScalarType::Float) => Ok(PrismaValue::Float(f)), - (PrismaValue::Float(f), ScalarType::Decimal) => Ok(PrismaValue::Float(f)), - (PrismaValue::Float(f), ScalarType::Int) => match f.to_i64() { + (PrismaValue::Float(f), &ScalarType::Float) => Ok(PrismaValue::Float(f)), + (PrismaValue::Float(f), &ScalarType::Decimal) => Ok(PrismaValue::Float(f)), + (PrismaValue::Float(f), &ScalarType::Int) => match f.to_i64() { Some(converted) => Ok(PrismaValue::Int(converted)), None => Err(ValidationError::value_too_large( selection_path.segments(), @@ -458,7 +461,7 @@ impl QueryDocumentParser { }, // UUID coercion matchers - (PrismaValue::Uuid(uuid), ScalarType::String) => Ok(PrismaValue::String(uuid.to_string())), + (PrismaValue::Uuid(uuid), &ScalarType::String) => Ok(PrismaValue::String(uuid.to_string())), // Generator calls cannot be encoded in the JSON protocol and can // only be injected by the query parser when evaluating the default @@ -507,7 +510,7 @@ impl QueryDocumentParser { (_, _) => Err(invalid_argument_type_error( selection_path, argument_path, - &InputType::Scalar(scalar_type), + &InputType::Scalar(scalar_type.clone()), argument_value, )), } @@ -722,7 +725,7 @@ impl QueryDocumentParser { }; let element_type_matches = match element_input_type { - InputType::Scalar(scalar_type) => prisma_value_type_matches_scalar_type(inner_type, *scalar_type), + InputType::Scalar(scalar_type) => prisma_value_type_matches_scalar_type(inner_type, scalar_type), InputType::Enum(_) => matches!(**inner_type, PrismaValueType::Enum), InputType::List(_) | InputType::Object(_) => { return Err(ValidationError::unexpected_runtime_error( @@ -957,22 +960,27 @@ impl QueryDocumentParser { } } -fn prisma_value_type_matches_scalar_type(pv_type: &PrismaValueType, scalar_type: ScalarType) -> bool { +fn prisma_value_type_matches_scalar_type(pv_type: &PrismaValueType, scalar_type: &ScalarType) -> bool { match pv_type { - PrismaValueType::String => matches!(scalar_type, ScalarType::String | ScalarType::UUID), - PrismaValueType::Boolean => scalar_type == ScalarType::Boolean, - PrismaValueType::Enum => scalar_type == ScalarType::String, + PrismaValueType::String => { + matches!( + scalar_type, + ScalarType::String | ScalarType::UUID | ScalarType::Geometry(_) + ) + } + PrismaValueType::Boolean => scalar_type == &ScalarType::Boolean, + PrismaValueType::Enum => scalar_type == &ScalarType::String, PrismaValueType::Int => matches!(scalar_type, ScalarType::Int | ScalarType::BigInt | ScalarType::Float), PrismaValueType::Uuid => matches!(scalar_type, ScalarType::UUID | ScalarType::String), PrismaValueType::List(prisma_value_type) => { matches!(**prisma_value_type, PrismaValueType::Json | PrismaValueType::Object) - && scalar_type == ScalarType::JsonList + && scalar_type == &ScalarType::JsonList } - PrismaValueType::Json | PrismaValueType::Object => scalar_type == ScalarType::Json, - PrismaValueType::DateTime => scalar_type == ScalarType::DateTime, + PrismaValueType::Json | PrismaValueType::Object => scalar_type == &ScalarType::Json, + PrismaValueType::DateTime => scalar_type == &ScalarType::DateTime, PrismaValueType::Float => matches!(scalar_type, ScalarType::Float | ScalarType::Decimal), - PrismaValueType::BigInt => scalar_type == ScalarType::BigInt, - PrismaValueType::Bytes => scalar_type == ScalarType::Bytes, + PrismaValueType::BigInt => scalar_type == &ScalarType::BigInt, + PrismaValueType::Bytes => matches!(scalar_type, ScalarType::Bytes | ScalarType::Geometry(_)), PrismaValueType::Any => true, } } diff --git a/query-compiler/dmmf/Cargo.toml b/query-compiler/dmmf/Cargo.toml index ff828f64732f..fab7ed2e7cd2 100644 --- a/query-compiler/dmmf/Cargo.toml +++ b/query-compiler/dmmf/Cargo.toml @@ -6,7 +6,7 @@ edition.workspace = true [dependencies] bigdecimal.workspace = true itertools.workspace = true -psl.workspace = true +psl = { workspace = true, features = ["all"] } serde.workspace = true serde_json.workspace = true schema.workspace = true diff --git a/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs b/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs index b8c7bcd0f432..31b5011ae309 100644 --- a/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs +++ b/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs @@ -5,11 +5,15 @@ use crate::serialization_ast::{ use bigdecimal::ToPrimitive; use itertools::{Either, Itertools}; use psl::{ - parser_database::{ScalarFieldType, walkers}, + parser_database::{GeometrySpec, ScalarFieldType, walkers}, schema_ast::ast::WithDocumentation, }; use query_structure::{DefaultKind, FieldArity, PrismaValue, dml_default_kind, encode_bytes}; +fn geometry_dmmf_field_type(spec: &GeometrySpec) -> String { + spec.postgres_sql_type() +} + pub(crate) fn schema_to_dmmf(schema: &psl::ValidatedSchema) -> Datamodel { let mut datamodel = Datamodel { models: Vec::with_capacity(schema.db.models_count()), @@ -86,7 +90,7 @@ fn composite_type_field_to_dmmf(field: walkers::CompositeTypeFieldWalker<'_>) -> kind: match field.r#type() { ScalarFieldType::CompositeType(_) => "object", ScalarFieldType::Enum(_) => "enum", - ScalarFieldType::BuiltInScalar(_) => "scalar", + ScalarFieldType::BuiltInScalar(_) | ScalarFieldType::Geometry(_) => "scalar", ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, db_name: field.mapped_name().map(ToOwned::to_owned), @@ -111,6 +115,7 @@ fn composite_type_field_to_dmmf(field: walkers::CompositeTypeFieldWalker<'_>) -> ScalarFieldType::CompositeType(ct) => field.walk(ct).name().to_owned(), ScalarFieldType::Enum(enm) => field.walk(enm).name().to_owned(), ScalarFieldType::BuiltInScalar(st) => st.as_str().to_owned(), + ScalarFieldType::Geometry(spec) => geometry_dmmf_field_type(&spec), ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, is_generated: None, @@ -181,7 +186,7 @@ fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>) -> Field { kind: match field.scalar_field_type() { ScalarFieldType::CompositeType(_) => "object", ScalarFieldType::Enum(_) => "enum", - ScalarFieldType::BuiltInScalar(_) => "scalar", + ScalarFieldType::BuiltInScalar(_) | ScalarFieldType::Geometry(_) => "scalar", ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, is_list: ast_field.arity.is_list(), @@ -199,6 +204,7 @@ fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>) -> Field { ScalarFieldType::CompositeType(ct) => field_walker.walk(ct).name().to_owned(), ScalarFieldType::Enum(enm) => field_walker.walk(enm).name().to_owned(), ScalarFieldType::BuiltInScalar(st) => st.as_str().to_owned(), + ScalarFieldType::Geometry(spec) => geometry_dmmf_field_type(&spec), ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, native_type: field diff --git a/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs b/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs index cdc0440aeb70..52736c6da94a 100644 --- a/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs +++ b/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs @@ -33,23 +33,24 @@ pub(super) fn render_output_type<'a>(output_type: &OutputType<'a>, ctx: &mut Ren } InnerOutputType::Scalar(scalar) => { - let stringified = match scalar { - ScalarType::Null => "Null", - ScalarType::String => "String", - ScalarType::Int => "Int", - ScalarType::BigInt => "BigInt", - ScalarType::Boolean => "Boolean", - ScalarType::Float => "Float", - ScalarType::Decimal => "Decimal", - ScalarType::DateTime => "DateTime", - ScalarType::Json => "Json", - ScalarType::UUID => "UUID", - ScalarType::JsonList => "Json", - ScalarType::Bytes => "Bytes", + let stringified: String = match scalar { + ScalarType::Null => "Null".into(), + ScalarType::String => "String".into(), + ScalarType::Int => "Int".into(), + ScalarType::BigInt => "BigInt".into(), + ScalarType::Boolean => "Boolean".into(), + ScalarType::Float => "Float".into(), + ScalarType::Decimal => "Decimal".into(), + ScalarType::DateTime => "DateTime".into(), + ScalarType::Json => "Json".into(), + ScalarType::UUID => "UUID".into(), + ScalarType::JsonList => "Json".into(), + ScalarType::Bytes => "Bytes".into(), + ScalarType::Geometry(s) => s.clone(), }; DmmfTypeReference { - typ: stringified.into(), + typ: stringified, namespace: None, location: TypeLocation::Scalar, is_list: false, diff --git a/query-compiler/dmmf/src/tests/tests.rs b/query-compiler/dmmf/src/tests/tests.rs index 48d5b4fb567d..0449ea7b84b0 100644 --- a/query-compiler/dmmf/src/tests/tests.rs +++ b/query-compiler/dmmf/src/tests/tests.rs @@ -1,5 +1,56 @@ use crate::{dmmf_from_schema, tests::setup::*}; +#[test] +fn geometry_fields_in_datamodel_and_schema_dmmf() { + let schema = r#" + datasource db { + provider = "postgresql" + } + + generator client { + provider = "prisma-client" + } + + model Location { + id Int @id + position Geometry(Point, 4326) + path Geometry(LineString)? + } + "#; + + let dmmf = dmmf_from_schema(schema); + let location = dmmf + .data_model + .models + .iter() + .find(|m| m.name == "Location") + .expect("Location model"); + let pos = location.fields.iter().find(|f| f.name == "position").unwrap(); + assert_eq!(pos.field_type, "geometry(Point,4326)"); + let path = location.fields.iter().find(|f| f.name == "path").unwrap(); + assert_eq!(path.field_type, "geometry(LineString,0)"); + + let schema_json = serde_json::to_value(&dmmf.schema).unwrap(); + let models = schema_json + .get("outputObjectTypes") + .and_then(|v| v.get("model")) + .and_then(|v| v.as_array()) + .expect("model output types"); + let location_out = models + .iter() + .find(|m| m.get("name").and_then(|n| n.as_str()) == Some("Location")) + .expect("Location output type"); + let fields = location_out.get("fields").and_then(|f| f.as_array()).unwrap(); + let pos_field = fields + .iter() + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("position")); + let out_pos = pos_field.and_then(|f| f.get("outputType")).expect("position output"); + assert_eq!( + out_pos.get("type").and_then(|t| t.as_str()), + Some("geometry(Point,4326)") + ); +} + #[test] fn sqlite_ignore() { let dmmf = dmmf_from_schema(include_str!("./test-schemas/sqlite_ignore.prisma")); diff --git a/query-compiler/query-builders/sql-query-builder/src/convert.rs b/query-compiler/query-builders/sql-query-builder/src/convert.rs index 5563ec467dfc..22348deeb889 100644 --- a/query-compiler/query-builders/sql-query-builder/src/convert.rs +++ b/query-compiler/query-builders/sql-query-builder/src/convert.rs @@ -221,6 +221,7 @@ pub fn type_identifier_to_opaque_type(identifier: &TypeIdentifier) -> OpaqueType TypeIdentifier::Json => OpaqueType::Json, TypeIdentifier::DateTime => OpaqueType::DateTime, TypeIdentifier::Bytes => OpaqueType::Bytes, + TypeIdentifier::Geometry(_) => OpaqueType::Bytes, TypeIdentifier::Extension(_) | TypeIdentifier::Unsupported => OpaqueType::Unknown, } } diff --git a/query-compiler/query-builders/sql-query-builder/src/model_extensions/scalar_field.rs b/query-compiler/query-builders/sql-query-builder/src/model_extensions/scalar_field.rs index 92bc3b9adc9f..c604afe05753 100644 --- a/query-compiler/query-builders/sql-query-builder/src/model_extensions/scalar_field.rs +++ b/query-compiler/query-builders/sql-query-builder/src/model_extensions/scalar_field.rs @@ -68,6 +68,7 @@ impl ScalarFieldExt for ScalarField { TypeIdentifier::Int => Value::null_int32(), TypeIdentifier::BigInt => Value::null_int64(), TypeIdentifier::Bytes => Value::null_bytes(), + TypeIdentifier::Geometry(_) => Value::null_bytes(), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }, (PrismaValue::Placeholder(PrismaValuePlaceholder { name, .. }), ident) => { @@ -107,6 +108,7 @@ impl ScalarFieldExt for ScalarField { TypeIdentifier::Json => TypeFamily::Text(Some(TypeDataLength::Maximum)), TypeIdentifier::DateTime => TypeFamily::DateTime, TypeIdentifier::Bytes => TypeFamily::Text(parse_scalar_length(self)), + TypeIdentifier::Geometry(_) => TypeFamily::Text(parse_scalar_length(self)), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), } } diff --git a/query-compiler/query-compiler/src/data_mapper.rs b/query-compiler/query-compiler/src/data_mapper.rs index 68ed20a24eef..caa588b205ec 100644 --- a/query-compiler/query-compiler/src/data_mapper.rs +++ b/query-compiler/query-compiler/src/data_mapper.rs @@ -17,6 +17,22 @@ use query_structure::{ use serde::Serialize; use std::{borrow::Cow, collections::HashMap, fmt}; +/// Maps a DMMF geometry field type string (e.g. `geometry(Point,4326)`) to the JSON protocol discriminator +/// consumed by `@prisma/client-engine-runtime`. +fn geometry_json_geometry_type(dmmf_type: &str) -> String { + let inner = dmmf_type.strip_prefix("geometry(").and_then(|s| s.strip_suffix(')')); + let subtype = inner + .and_then(|s| s.split(',').next()) + .map(str::trim) + .unwrap_or("Geometry"); + match subtype { + "Point" => "point".to_owned(), + "LineString" => "linestring".to_owned(), + "Polygon" => "polygon".to_owned(), + _ => "geometry".to_owned(), + } +} + pub fn map_result_structure(graph: &QueryGraph, builder: &mut ResultNodeBuilder) -> Option { graph .result_nodes() @@ -409,6 +425,9 @@ pub enum FieldScalarType { Bytes { encoding: ByteArrayEncoding, }, + Geometry { + geometry_type: String, + }, Unsupported, } @@ -427,6 +446,7 @@ impl fmt::Display for FieldScalarType { Self::Object => write!(f, "Object"), Self::DateTime => write!(f, "DateTime"), Self::Bytes { .. } => write!(f, "Bytes"), + Self::Geometry { geometry_type } => write!(f, "Geometry({geometry_type})"), Self::Unsupported => write!(f, "Unsupported"), } } @@ -434,7 +454,7 @@ impl fmt::Display for FieldScalarType { impl From<&Type> for FieldScalarType { fn from(typ: &Type) -> Self { - match typ.id { + match &typ.id { TypeIdentifier::String => Self::String, TypeIdentifier::Int => Self::Int, TypeIdentifier::BigInt => Self::BigInt, @@ -442,14 +462,14 @@ impl From<&Type> for FieldScalarType { TypeIdentifier::Decimal => Self::Decimal, TypeIdentifier::Boolean => Self::Boolean, TypeIdentifier::Enum(id) => Self::Enum { - name: typ.dm.clone().zip(id).name().to_owned(), + name: typ.dm.clone().zip(*id).name().to_owned(), }, TypeIdentifier::Extension(id) => Self::Extension { name: typ .dm .schema .db - .get_extension_type_prisma_name(id) + .get_extension_type_prisma_name(*id) .expect("extension type not found") .to_owned(), }, @@ -459,6 +479,9 @@ impl From<&Type> for FieldScalarType { TypeIdentifier::Bytes => Self::Bytes { encoding: ByteArrayEncoding::default(), }, + TypeIdentifier::Geometry(dmmf) => Self::Geometry { + geometry_type: geometry_json_geometry_type(dmmf), + }, TypeIdentifier::Unsupported => Self::Unsupported, } } diff --git a/query-compiler/query-compiler/tests/data/geometry-find-many.json b/query-compiler/query-compiler/tests/data/geometry-find-many.json new file mode 100644 index 000000000000..01457f83cf07 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-find-many.json @@ -0,0 +1,11 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": {}, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/schema.prisma b/query-compiler/query-compiler/tests/data/schema.prisma index 26f3c7eb6ad8..310f7af17d45 100644 --- a/query-compiler/query-compiler/tests/data/schema.prisma +++ b/query-compiler/query-compiler/tests/data/schema.prisma @@ -85,6 +85,11 @@ model DataTypes { intArray Int[] } +model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326) +} + model Patient { id Int @id @default(autoincrement()) userId Int @unique diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap new file mode 100644 index 000000000000..0d804411f0e1 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap @@ -0,0 +1,11 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-find-many.json +--- +dataMap { + id: Int (id) + position: Geometry(point) (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0"» +params [] diff --git a/query-compiler/query-structure/src/field/mod.rs b/query-compiler/query-structure/src/field/mod.rs index 2b118680f746..d3c7cc425ead 100644 --- a/query-compiler/query-structure/src/field/mod.rs +++ b/query-compiler/query-structure/src/field/mod.rs @@ -130,7 +130,7 @@ impl Field { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[allow(clippy::upper_case_acronyms)] pub enum TypeIdentifier { String, @@ -145,6 +145,8 @@ pub enum TypeIdentifier { Json, DateTime, Bytes, + /// String-encoded DMMF geometry type, e.g. `geometry(Point,4326)`. + Geometry(String), Unsupported, } @@ -171,7 +173,7 @@ pub type Type = Zipper; impl Type { pub fn type_name(&self) -> Cow<'static, str> { - match self.id { + match &self.id { TypeIdentifier::String => "String".into(), TypeIdentifier::Int => "Int".into(), TypeIdentifier::BigInt => "BigInt".into(), @@ -179,14 +181,14 @@ impl Type { TypeIdentifier::Decimal => "Decimal".into(), TypeIdentifier::Boolean => "Bool".into(), TypeIdentifier::Enum(enum_id) => { - let enum_name = self.dm.walk(enum_id).name(); + let enum_name = self.dm.walk(*enum_id).name(); format!("Enum{enum_name}").into() } TypeIdentifier::Extension(ext_id) => self .dm .schema .db - .get_extension_type_prisma_name(ext_id) + .get_extension_type_prisma_name(*ext_id) .expect("extension type name should be present") .to_owned() .into(), @@ -194,12 +196,13 @@ impl Type { TypeIdentifier::Json => "Json".into(), TypeIdentifier::DateTime => "DateTime".into(), TypeIdentifier::Bytes => "Bytes".into(), + TypeIdentifier::Geometry(s) => s.clone().into(), TypeIdentifier::Unsupported => "Unsupported".into(), } } pub fn to_prisma_type(&self) -> PrismaValueType { - match self.id { + match &self.id { TypeIdentifier::String => PrismaValueType::String, TypeIdentifier::Int => PrismaValueType::Int, TypeIdentifier::BigInt => PrismaValueType::BigInt, @@ -211,6 +214,7 @@ impl Type { TypeIdentifier::Json => PrismaValueType::Json, TypeIdentifier::DateTime => PrismaValueType::DateTime, TypeIdentifier::Bytes => PrismaValueType::Bytes, + TypeIdentifier::Geometry(_) => PrismaValueType::Bytes, TypeIdentifier::Extension(_) | TypeIdentifier::Unsupported => PrismaValueType::Any, } } diff --git a/query-compiler/query-structure/src/field/scalar.rs b/query-compiler/query-structure/src/field/scalar.rs index fd1e82073fce..254de601ba11 100644 --- a/query-compiler/query-structure/src/field/scalar.rs +++ b/query-compiler/query-structure/src/field/scalar.rs @@ -2,11 +2,15 @@ use crate::{DefaultKind, NativeTypeInstance, ValueGenerator, ast, parent_contain use chrono::{DateTime, FixedOffset}; use psl::{ generators::{DEFAULT_CUID_VERSION, DEFAULT_UUID_VERSION}, - parser_database::{self as db, ScalarFieldType, ScalarType, walkers}, + parser_database::{self as db, GeometrySpec, ScalarFieldType, ScalarType, walkers}, schema_ast::ast::FieldArity, }; use std::fmt::{Debug, Display}; +fn geometry_dmmf_string(spec: &GeometrySpec) -> String { + spec.postgres_sql_type() +} + pub type ScalarField = crate::Zipper; pub type ScalarFieldRef = ScalarField; @@ -98,6 +102,7 @@ impl ScalarField { ScalarFieldType::Enum(x) => TypeIdentifier::Enum(x), ScalarFieldType::Extension(udt) => TypeIdentifier::Extension(udt), ScalarFieldType::BuiltInScalar(scalar) => scalar.into(), + ScalarFieldType::Geometry(spec) => TypeIdentifier::Geometry(geometry_dmmf_string(&spec)), ScalarFieldType::Unsupported(_) => TypeIdentifier::Unsupported, } } diff --git a/query-compiler/query-structure/src/prisma_value_ext.rs b/query-compiler/query-structure/src/prisma_value_ext.rs index 4d6cd19b2bd6..946f31e393b9 100644 --- a/query-compiler/query-structure/src/prisma_value_ext.rs +++ b/query-compiler/query-structure/src/prisma_value_ext.rs @@ -10,7 +10,7 @@ pub(crate) trait PrismaValueExtensions { impl PrismaValueExtensions for PrismaValue { // Todo this is not exhaustive for now. fn coerce(self, to_type: &Type) -> crate::Result { - let coerced = match (self, to_type.id) { + let coerced = match (self, to_type.id.clone()) { // Trivial cases (PrismaValue::Null, _) => PrismaValue::Null, (val @ PrismaValue::String(_), TypeIdentifier::String) => val, @@ -23,6 +23,8 @@ impl PrismaValueExtensions for PrismaValue { (val @ PrismaValue::Uuid(_), TypeIdentifier::UUID) => val, (val @ PrismaValue::BigInt(_), TypeIdentifier::BigInt) => val, (val @ PrismaValue::Bytes(_), TypeIdentifier::Bytes) => val, + (val @ PrismaValue::Bytes(_), TypeIdentifier::Geometry(_)) => val, + (val @ PrismaValue::String(_), TypeIdentifier::Geometry(_)) => val, (val @ PrismaValue::Json(_), TypeIdentifier::Json) => val, // Valid String coercions diff --git a/query-compiler/request-handlers/src/protocols/json/protocol_adapter.rs b/query-compiler/request-handlers/src/protocols/json/protocol_adapter.rs index f20e898a7a73..2840a18afa74 100644 --- a/query-compiler/request-handlers/src/protocols/json/protocol_adapter.rs +++ b/query-compiler/request-handlers/src/protocols/json/protocol_adapter.rs @@ -209,6 +209,11 @@ impl<'a> JsonProtocolAdapter<'a> { decode_bytes(value).map(ArgumentValue::bytes).map_err(|_| build_err()) } + Some(custom_types::GEOMETRY) => { + let value = obj.get(custom_types::VALUE).ok_or_else(build_err)?; + let bytes = serde_json::to_vec(value).map_err(|_| build_err())?; + Ok(ArgumentValue::bytes(bytes)) + } Some(custom_types::JSON) => { let value = obj .remove(custom_types::VALUE) @@ -1238,6 +1243,96 @@ mod tests { "###); } + #[test] + fn custom_arg_geometry() { + let query: JsonSingleQuery = serde_json::from_str( + r#"{ + "modelName": "User", + "action": "updateOne", + "query": { + "arguments": { + "data": { + "x": { "$type": "Geometry", "value": { "type": "Point", "coordinates": [13.4, 52.5], "srid": 4326 } } + } + }, + "selection": { + "$scalars": true + } + } + }"#, + ) + .unwrap(); + + let operation = JsonProtocolAdapter::new(&schema()).convert_single(query).unwrap(); + + assert_debug_snapshot!(operation.arguments()[0].1, @r###" + Object( + { + "x": Scalar( + Bytes( + [ + 123, + 34, + 116, + 121, + 112, + 101, + 34, + 58, + 34, + 80, + 111, + 105, + 110, + 116, + 34, + 44, + 34, + 99, + 111, + 111, + 114, + 100, + 105, + 110, + 97, + 116, + 101, + 115, + 34, + 58, + 91, + 49, + 51, + 46, + 52, + 44, + 53, + 50, + 46, + 53, + 93, + 44, + 34, + 115, + 114, + 105, + 100, + 34, + 58, + 52, + 51, + 50, + 54, + 125, + ], + ), + ), + }, + ) + "###); + } + #[test] fn custom_arg_json() { let query: JsonSingleQuery = serde_json::from_str( diff --git a/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs b/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs index 5cb155c0892e..a132e3e78390 100644 --- a/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs +++ b/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs @@ -45,6 +45,9 @@ impl DataInputFieldMapper for UpdateDataInputFieldMapper { } TypeIdentifier::UUID => InputType::object(update_operations_object_type(ctx, "Uuid", sf.clone(), false)), TypeIdentifier::Bytes => InputType::object(update_operations_object_type(ctx, "Bytes", sf.clone(), false)), + TypeIdentifier::Geometry(_) => { + InputType::object(update_operations_object_type(ctx, "Bytes", sf.clone(), false)) + } TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }; diff --git a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs index fd72d44dbff5..86af2cdd5983 100644 --- a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs +++ b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs @@ -223,7 +223,12 @@ fn full_scalar_filter_type( include_aggregates: bool, ) -> InputObjectType<'_> { let native_type_name = native_type.as_ref().map(|nt| nt.name()); - let scalar_type_name = ctx.internal_data_model.clone().zip(typ).type_name().into_owned(); + let scalar_type_name = ctx + .internal_data_model + .clone() + .zip(typ.clone()) + .type_name() + .into_owned(); let type_name = ctx.connector.scalar_filter_name(scalar_type_name, native_type_name); let ident = Identifier::new_prisma(scalar_filter_name( &type_name, @@ -236,8 +241,8 @@ fn full_scalar_filter_type( let mut object = init_input_object_type(ident); object.set_fields(move || { - let mapped_scalar_type = map_scalar_input_type(ctx, typ, list); - let mut fields: Vec<_> = match typ { + let mapped_scalar_type = map_scalar_input_type(ctx, typ.clone(), list); + let mut fields: Vec<_> = match &typ { TypeIdentifier::String | TypeIdentifier::UUID => equality_filters(mapped_scalar_type.clone(), nullable) .chain(inclusion_filters(ctx, mapped_scalar_type.clone(), nullable)) .chain(alphanumeric_filters(ctx, mapped_scalar_type.clone())) @@ -274,6 +279,8 @@ fn full_scalar_filter_type( TypeIdentifier::Boolean => equality_filters(mapped_scalar_type.clone(), nullable).collect(), + TypeIdentifier::Geometry(_) => equality_filters(mapped_scalar_type.clone(), nullable).collect(), + TypeIdentifier::Bytes | TypeIdentifier::Enum(_) => equality_filters(mapped_scalar_type.clone(), nullable) .chain(inclusion_filters(ctx, mapped_scalar_type.clone(), nullable)) .collect(), @@ -285,7 +292,7 @@ fn full_scalar_filter_type( fields.push(not_filter_field( ctx, - typ, + typ.clone(), native_type.clone(), mapped_scalar_type, nullable, @@ -303,7 +310,7 @@ fn full_scalar_filter_type( )); if typ.is_numeric() { - let avg_type = map_avg_type_ident(typ); + let avg_type = map_avg_type_ident(typ.clone()); fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_AVG, @@ -315,7 +322,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_SUM, - typ, + typ.clone(), nullable, list, )); @@ -325,7 +332,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_MIN, - typ, + typ.clone(), nullable, list, )); @@ -333,7 +340,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_MAX, - typ, + typ.clone(), nullable, list, )); diff --git a/query-compiler/schema/src/build/input_types/mod.rs b/query-compiler/schema/src/build/input_types/mod.rs index 0ec786546f21..d3df06e9c82e 100644 --- a/query-compiler/schema/src/build/input_types/mod.rs +++ b/query-compiler/schema/src/build/input_types/mod.rs @@ -23,6 +23,7 @@ fn map_scalar_input_type(ctx: &'_ QuerySchema, typ: TypeIdentifier, list: bool) TypeIdentifier::Extension(_) => unreachable!("No extension field should reach this path"), TypeIdentifier::Bytes => InputType::bytes(), TypeIdentifier::BigInt => InputType::bigint(), + TypeIdentifier::Geometry(s) => InputType::Scalar(ScalarType::Geometry(s)), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }; diff --git a/query-compiler/schema/src/build/output_types/field.rs b/query-compiler/schema/src/build/output_types/field.rs index b23288fc7308..50ef922e4b21 100644 --- a/query-compiler/schema/src/build/output_types/field.rs +++ b/query-compiler/schema/src/build/output_types/field.rs @@ -40,6 +40,7 @@ pub(crate) fn map_scalar_output_type<'a>(ctx: &'a QuerySchema, typ: &TypeIdentif TypeIdentifier::Int => OutputType::int(), TypeIdentifier::Bytes => OutputType::bytes(), TypeIdentifier::BigInt => OutputType::bigint(), + TypeIdentifier::Geometry(s) => OutputType::geometry(s.clone()), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }; diff --git a/query-compiler/schema/src/output_types.rs b/query-compiler/schema/src/output_types.rs index 2a1eb7bb319e..a27ee5db56ad 100644 --- a/query-compiler/schema/src/output_types.rs +++ b/query-compiler/schema/src/output_types.rs @@ -76,6 +76,10 @@ impl<'a> OutputType<'a> { InnerOutputType::Scalar(ScalarType::Bytes) } + pub(crate) fn geometry(dmmf_type: String) -> InnerOutputType<'a> { + InnerOutputType::Scalar(ScalarType::Geometry(dmmf_type)) + } + /// Attempts to recurse through the type until an object type is found. /// Returns Some(ObjectTypeStrongRef) if ab object type is found, None otherwise. pub fn as_object_type<'b>(&'b self) -> Option<&'b ObjectType<'a>> { diff --git a/query-compiler/schema/src/query_schema.rs b/query-compiler/schema/src/query_schema.rs index 58e850d5946e..029e0065044c 100644 --- a/query-compiler/schema/src/query_schema.rs +++ b/query-compiler/schema/src/query_schema.rs @@ -359,7 +359,7 @@ impl Identifier { } } -#[derive(Debug, Clone, PartialEq, Copy)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ScalarType { Null, String, @@ -373,6 +373,8 @@ pub enum ScalarType { JsonList, UUID, Bytes, + /// DMMF string form, e.g. `geometry(Point,4326)`. + Geometry(String), } impl fmt::Display for ScalarType { @@ -390,6 +392,7 @@ impl fmt::Display for ScalarType { ScalarType::UUID => "UUID", ScalarType::JsonList => "Json", ScalarType::Bytes => "Bytes", + ScalarType::Geometry(s) => s.as_str(), }; f.write_str(typ) diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs index 3fc93e50078e..f1c62e5feecd 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs @@ -549,6 +549,10 @@ impl SqlRenderer for PostgresRenderer { fn render_column_type(col: TableColumnWalker<'_>, renderer: &PostgresRenderer) -> Cow<'static, str> { let t = col.column_type(); + if let ColumnTypeFamily::Geometry(spec) = &t.family { + let sql = spec.postgres_sql_type(); + return format!("{sql}{}", if t.arity.is_list() { "[]" } else { "" }).into(); + } if let Some(enm) = col.column_type_family_as_enum() { let name = QuotedWithPrefix::pg_new(enm.explicit_namespace(), enm.name()); let arity = if t.arity.is_list() { "[]" } else { "" }; @@ -575,7 +579,12 @@ fn render_column_type_postgres(col: TableColumnWalker<'_>) -> Cow<'static, str> .expect("Missing native type in postgres_renderer::render_column_type()") { PostgresType::Known(known) => known, - PostgresType::Unknown(name, args) => return format!("{}({})", name, args.iter().format(", ")).into(), + PostgresType::Unknown(name, args) => { + if args.is_empty() { + return name.clone().into(); + } + return format!("{}({})", name, args.iter().format(", ")).into(); + } }; let tpe: Cow<'_, str> = match native_type { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs index df57b35205fc..f5eed3232cde 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs @@ -14,6 +14,7 @@ use enumflags2::BitFlags; use psl::builtin_connectors::{CockroachType, KnownPostgresType, PostgresType}; use regex::RegexSet; use sql_schema_describer::{ + ColumnTypeFamily, postgres::PostgresSchemaExt, walkers::{IndexWalker, TableColumnWalker}, }; @@ -382,6 +383,19 @@ fn postgres_column_type_change(columns: MigrationPair>) -> let from_list_to_scalar = columns.previous.arity().is_list() && !columns.next.arity().is_list(); let from_scalar_to_list = !columns.previous.arity().is_list() && columns.next.arity().is_list(); + match (columns.previous.column_type_family(), columns.next.column_type_family()) { + (ColumnTypeFamily::Geometry(prev), ColumnTypeFamily::Geometry(next)) => { + if from_list_to_scalar || from_scalar_to_list { + return Some(NotCastable); + } + if prev == next { + return None; + } + return Some(RiskyCast); + } + _ => {} + } + match (previous_type, next_type) { (_, Some(PostgresType::Known(KnownPostgresType::Text))) if from_list_to_scalar => Some(SafeCast), (_, Some(PostgresType::Known(KnownPostgresType::VarChar(None)))) if from_list_to_scalar => Some(SafeCast), diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/renderer.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/renderer.rs index 45553f29d097..5c0f77694c24 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/renderer.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/renderer.rs @@ -274,6 +274,7 @@ fn render_column_type(t: &ColumnType) -> &str { ColumnTypeFamily::Enum(_) => unreachable!("ColumnTypeFamily::Enum on SQLite"), ColumnTypeFamily::Uuid => unimplemented!("ColumnTypeFamily::Uuid on SQLite"), ColumnTypeFamily::Udt(_) => unimplemented!("ColumnTypeFamily::Udt on SQLite"), + ColumnTypeFamily::Geometry(_) => unreachable!("ColumnTypeFamily::Geometry on SQLite"), ColumnTypeFamily::Unsupported(x) => x.as_ref(), } } diff --git a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs index ca20d237b9db..8d1a5aab88f0 100644 --- a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs +++ b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs @@ -107,6 +107,16 @@ impl<'a> ScalarFieldPair<'a> { sql::ColumnTypeFamily::Binary => Cow::from("Bytes"), sql::ColumnTypeFamily::Json => Cow::from("Json"), sql::ColumnTypeFamily::Uuid => Cow::from("String"), + sql::ColumnTypeFamily::Geometry(spec) => { + use std::fmt::Write; + let mut out = String::from("Geometry("); + out.push_str(spec.subtype.as_str()); + if let Some(srid) = spec.srid { + write!(&mut out, ", {srid}").unwrap(); + } + out.push(')'); + Cow::Owned(out) + } sql::ColumnTypeFamily::Enum(id) => self.context.enum_prisma_name(*id).prisma_name(), &sql::ColumnTypeFamily::Udt(id) => self .extension_type() @@ -153,6 +163,7 @@ impl<'a> ScalarFieldPair<'a> { sql::ColumnTypeFamily::Json => psl::parser_database::ScalarType::Json, sql::ColumnTypeFamily::Uuid => psl::parser_database::ScalarType::String, sql::ColumnTypeFamily::Binary => psl::parser_database::ScalarType::Bytes, + sql::ColumnTypeFamily::Geometry(spec) => return Some(ScalarFieldType::Geometry(*spec)), sql::ColumnTypeFamily::Udt(_) => { let entry = self.extension_type()?; return Some(ScalarFieldType::Extension(entry.id)); diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs index d560ec0a5c68..ee830d464240 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs @@ -11,8 +11,8 @@ use psl::{ ValidatedSchema, datamodel_connector::walker_ext_traits::*, parser_database::{ - self as db, ExtensionTypeId, ExtensionTypes, ReferentialAction, ScalarFieldType, ScalarType, SortOrder, - WhereClause, WhereCondition, WhereValue, ast, + self as db, ExtensionTypeId, ExtensionTypes, GeometrySpec, ReferentialAction, ScalarFieldType, ScalarType, + SortOrder, WhereClause, WhereCondition, WhereValue, ast, walkers::{IndexWalker, ModelWalker, ScalarFieldWalker}, }, }; @@ -444,10 +444,47 @@ fn push_column_for_scalar_field(field: ScalarFieldWalker<'_>, table_id: sql::Tab ScalarFieldType::BuiltInScalar(scalar_type) => { push_column_for_builtin_scalar_type(field, scalar_type, table_id, ctx) } + ScalarFieldType::Geometry(spec) => push_column_for_geometry_field(field, spec, table_id, ctx), ScalarFieldType::Unsupported(_) => push_column_for_model_unsupported_scalar_field(field, table_id, ctx), } } +fn push_column_for_geometry_field( + field: ScalarFieldWalker<'_>, + spec: GeometrySpec, + table_id: sql::TableId, + ctx: &mut Context<'_>, +) { + let connector = ctx.flavour.datamodel_connector(); + let native_type = field + .native_type_instance(connector) + .or_else(|| connector.default_native_type_for_scalar_type(&ScalarFieldType::Geometry(spec), ctx.datamodel)); + + let default = field.default_value().map(|def| { + sql::DefaultValue::db_generated::(unwrap_dbgenerated(def.value())) + .with_constraint_name(ctx.flavour.default_constraint_name(def)) + }); + + if let Some(default) = default { + let column_id = ctx.schema.describer_schema.next_table_column_id(); + ctx.schema.describer_schema.push_table_default_value(column_id, default); + } + + let column = sql::Column { + name: field.database_name().to_owned(), + tpe: sql::ColumnType { + family: sql::ColumnTypeFamily::Geometry(spec), + full_data_type: String::new(), + arity: column_arity(field.ast_field().arity), + native_type, + }, + auto_increment: false, + description: None, + }; + + ctx.schema.describer_schema.push_table_column(table_id, column); +} + fn push_column_for_model_enum_scalar_field( field: ScalarFieldWalker<'_>, enum_id: db::EnumId, diff --git a/schema-engine/sql-introspection-tests/tests/postgres/mod.rs b/schema-engine/sql-introspection-tests/tests/postgres/mod.rs index c490f667f4d8..8221694b0ab8 100644 --- a/schema-engine/sql-introspection-tests/tests/postgres/mod.rs +++ b/schema-engine/sql-introspection-tests/tests/postgres/mod.rs @@ -3,6 +3,7 @@ mod constraints; mod extensions; mod gin; mod gist; +mod postgis_geometry; mod spgist; use indoc::indoc; diff --git a/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs new file mode 100644 index 000000000000..7e61cc9e756d --- /dev/null +++ b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs @@ -0,0 +1,29 @@ +use indoc::indoc; +use sql_introspection_tests::test_api::*; +use test_macros::test_connector; + +// PostGIS extension plus a table with multiple `geometry` columns (typmod / SRID variants). +#[test_connector(tags(Postgres), exclude(CockroachDb), preview_features("postgresqlExtensions"))] +async fn introspect_geometry_columns(api: &mut TestApi) -> TestResult { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis").await; + api.raw_cmd(indoc! {r#" + CREATE TABLE locations ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + position geometry(Point, 4326), + path geometry(LineString), + area geometry(Polygon, 3857) NOT NULL + ); + "#}) + .await; + + let schema = api.introspect().await?; + + assert!(schema.contains("extensions = [postgis")); + assert!(schema.contains("model locations")); + assert!(schema.contains("position Geometry(Point, 4326)?")); + assert!(schema.contains("path") && schema.contains("Geometry(LineString)")); + assert!(schema.contains("area Geometry(Polygon, 3857)")); + + Ok(()) +} diff --git a/schema-engine/sql-migration-tests/tests/migrations/postgres.rs b/schema-engine/sql-migration-tests/tests/migrations/postgres.rs index a9717a199d24..20bec44d618e 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/postgres.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/postgres.rs @@ -1,6 +1,7 @@ mod extensions; mod introspection; mod multi_schema; +mod postgis_geometry; use psl::parser_database::{NoExtensionTypes, SourceFile}; use quaint::Value; diff --git a/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs new file mode 100644 index 000000000000..022e91ebab09 --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs @@ -0,0 +1,88 @@ +use indoc::indoc; +use psl::parser_database::NoExtensionTypes; +use schema_core::schema_connector::{CompositeTypeDepth, IntrospectionContext, SchemaConnector}; +use sql_migration_tests::test_api::*; +use test_macros::test_connector; + +#[test_connector(tags(Postgres), exclude(CockroachDb))] +fn create_table_with_geometry(api: TestApi) { + let dm = indoc! {r#" + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#}; + + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis"); + + api.schema_push_w_datasource(dm).send().assert_green(); + + let connector = psl::builtin_connectors::POSTGRES; + api.assert_schema().assert_table("Location", |table| { + table.assert_column("position", |col| { + col.assert_native_type("geometry(Point,4326)", connector) + }) + }); +} + +#[test_connector(tags(Postgres), exclude(CockroachDb))] +fn alter_geometry_srid(api: TestApi) { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis"); + + let schema1 = indoc! {r#" + model Location { + id Int @id + position Geometry(Point, 4326) + } + "#}; + + api.schema_push_w_datasource(schema1).send().assert_green(); + + let schema2 = indoc! {r#" + model Location { + id Int @id + position Geometry(Point, 3857) + } + "#}; + + api.schema_push_w_datasource(schema2).send().assert_green(); + + let connector = psl::builtin_connectors::POSTGRES; + api.assert_schema().assert_table("Location", |table| { + table.assert_column("position", |col| { + col.assert_native_type("geometry(Point,3857)", connector) + }) + }); +} + +#[test_connector(tags(Postgres), exclude(CockroachDb), preview_features("postgresqlExtensions"))] +fn geometry_round_trip(mut api: TestApi) { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis"); + + let dm = indoc! {r#" + model Location { + id Int @id + position Geometry(Point, 4326)? + path Geometry(LineString, 4326)? + } + "#}; + + api.schema_push_w_datasource(dm).send().assert_green(); + + let schema = api.datamodel_with_provider(dm); + let previous_schema = psl::validate_without_extensions(schema.into()); + let mut ctx = IntrospectionContext::new( + previous_schema, + CompositeTypeDepth::Infinite, + None, + std::path::PathBuf::new(), + ); + ctx.render_config = false; + + let introspected = tok(api.connector.introspect(&ctx, &NoExtensionTypes)) + .unwrap() + .into_single_datamodel(); + + assert!(introspected.contains("Geometry(Point, 4326)")); + assert!(introspected.contains("Geometry(LineString, 4326)")); +} diff --git a/schema-engine/sql-schema-describer/src/lib.rs b/schema-engine/sql-schema-describer/src/lib.rs index c731c5236b27..4a2658c0d713 100644 --- a/schema-engine/sql-schema-describer/src/lib.rs +++ b/schema-engine/sql-schema-describer/src/lib.rs @@ -29,6 +29,7 @@ use indexmap::IndexSet; pub use prisma_value::PrismaValue; use enumflags2::{BitFlag, BitFlags}; +use psl::parser_database::GeometrySpec; use regex::Regex; use serde::{Deserialize, Serialize}; use std::fmt::{self, Debug}; @@ -704,6 +705,8 @@ pub enum ColumnTypeFamily { Enum(EnumId), /// User-defined type Udt(UdtId), + /// PostGIS `geometry` type (subtype + SRID from typmod / `format_type`). + Geometry(GeometrySpec), /// Unsupported Unsupported(String), } @@ -747,6 +750,10 @@ impl ColumnTypeFamily { pub fn is_unsupported(&self) -> bool { matches!(self, ColumnTypeFamily::Unsupported(_)) } + + pub fn is_geometry(&self) -> bool { + matches!(self, ColumnTypeFamily::Geometry(_)) + } } /// A column's arity. diff --git a/schema-engine/sql-schema-describer/src/mssql.rs b/schema-engine/sql-schema-describer/src/mssql.rs index 28f7a8924f41..dbd3dc06a3db 100644 --- a/schema-engine/sql-schema-describer/src/mssql.rs +++ b/schema-engine/sql-schema-describer/src/mssql.rs @@ -326,6 +326,7 @@ impl<'a> SqlSchemaDescriber<'a> { ColumnTypeFamily::Udt(_) | ColumnTypeFamily::Unsupported(_) => { DefaultValue::db_generated(default_string) } + ColumnTypeFamily::Geometry(_) => DefaultValue::db_generated(default_string), ColumnTypeFamily::Enum(_) => unreachable!("No enums in MSSQL"), }; diff --git a/schema-engine/sql-schema-describer/src/mysql.rs b/schema-engine/sql-schema-describer/src/mysql.rs index ee353b9d9418..bf1b3ddf4ddb 100644 --- a/schema-engine/sql-schema-describer/src/mysql.rs +++ b/schema-engine/sql-schema-describer/src/mysql.rs @@ -478,7 +478,9 @@ impl<'a> SqlSchemaDescriber<'a> { ))) } } - ColumnTypeFamily::Udt(_) | ColumnTypeFamily::Unsupported(_) => match default_expression { + ColumnTypeFamily::Udt(_) + | ColumnTypeFamily::Unsupported(_) + | ColumnTypeFamily::Geometry(_) => match default_expression { true => Self::dbgenerated_expression(&default_string), false => DefaultValue::db_generated(default_string), }, diff --git a/schema-engine/sql-schema-describer/src/postgres.rs b/schema-engine/sql-schema-describer/src/postgres.rs index 97815057157f..ebcd7c7dfa7e 100644 --- a/schema-engine/sql-schema-describer/src/postgres.rs +++ b/schema-engine/sql-schema-describer/src/postgres.rs @@ -15,6 +15,7 @@ use indoc::indoc; use psl::{ builtin_connectors::{CockroachType, KnownPostgresType, PostgresType}, datamodel_connector::NativeTypeInstance, + parser_database::{GeometrySpec, GeometrySubtype, PostgisSpatialKind}, }; use quaint::{Value, connector::ResultRow, prelude::Queryable}; use regex::Regex; @@ -1532,6 +1533,62 @@ fn index_from_row( } } +fn map_geometry_subtype(pg_name: &str) -> GeometrySubtype { + match pg_name.trim().to_uppercase().as_str() { + "POINT" => GeometrySubtype::Point, + "LINESTRING" => GeometrySubtype::LineString, + "POLYGON" => GeometrySubtype::Polygon, + "MULTIPOINT" => GeometrySubtype::MultiPoint, + "MULTILINESTRING" => GeometrySubtype::MultiLineString, + "MULTIPOLYGON" => GeometrySubtype::MultiPolygon, + "GEOMETRYCOLLECTION" => GeometrySubtype::GeometryCollection, + "GEOMETRY" => GeometrySubtype::Geometry, + _ => GeometrySubtype::Geometry, + } +} + +/// Parse PostGIS `geometry` / `geography` from `format_type(atttypid, atttypmod)` (e.g. `geometry(Point,4326)`). +fn parse_postgis_spatial(formatted_type: &str, spatial: PostgisSpatialKind) -> GeometrySpec { + static RE_TWO: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)^(geometry|geography)\s*\(\s*([^,()]+)\s*,\s*(-?\d+)\s*\)\s*$").unwrap()); + static RE_ONE: LazyLock = + LazyLock::new(|| Regex::new(r"(?i)^(geometry|geography)\s*\(\s*([^,()]+)\s*\)\s*$").unwrap()); + + let trimmed = formatted_type.trim(); + if trimmed.eq_ignore_ascii_case("geometry") || trimmed.eq_ignore_ascii_case("geography") { + return GeometrySpec { + subtype: GeometrySubtype::Geometry, + srid: None, + spatial, + }; + } + + if let Some(caps) = RE_TWO.captures(trimmed) { + let subtype_str = caps.get(2).map(|m| m.as_str()).unwrap_or("GEOMETRY"); + let srid: i32 = caps.get(3).and_then(|m| m.as_str().parse().ok()).unwrap_or(0); + return GeometrySpec { + subtype: map_geometry_subtype(subtype_str), + srid: Some(srid), + spatial, + }; + } + + if let Some(caps) = RE_ONE.captures(trimmed) { + let subtype_str = caps.get(2).map(|m| m.as_str()).unwrap_or("GEOMETRY"); + return GeometrySpec { + subtype: map_geometry_subtype(subtype_str), + srid: None, + spatial, + }; + } + + GeometrySpec { + subtype: GeometrySubtype::Geometry, + srid: None, + spatial, + } +} + fn get_column_type_postgresql(row: &ResultRow, schema: &SqlSchema) -> ColumnType { let data_type = row.get_expect_string("data_type"); let full_data_type = row.get_expect_string("full_data_type"); @@ -1564,6 +1621,24 @@ fn get_column_type_family( ) -> (ColumnTypeFamily, Option) { use ColumnTypeFamily::*; + if full_data_type == "geometry" && data_type == "USER-DEFINED" { + let spec = parse_postgis_spatial(&row.get_expect_string("formatted_type"), PostgisSpatialKind::Geometry); + let sql = spec.postgres_sql_type(); + return ( + ColumnTypeFamily::Geometry(spec), + Some(PostgresType::Unknown(sql, Vec::new())), + ); + } + + if full_data_type == "geography" && data_type == "USER-DEFINED" { + let spec = parse_postgis_spatial(&row.get_expect_string("formatted_type"), PostgisSpatialKind::Geography); + let sql = spec.postgres_sql_type(); + return ( + ColumnTypeFamily::Geometry(spec), + Some(PostgresType::Unknown(sql, Vec::new())), + ); + } + let precision = SqlSchemaDescriber::get_precision(row); let (t, nt) = match full_data_type { @@ -1749,3 +1824,45 @@ fn get_column_type_cockroachdb(row: &ResultRow, schema: &SqlSchema) -> ColumnTyp native_type: native_type.map(NativeTypeInstance::new::), } } + +#[cfg(test)] +mod postgis_geometry_tests { + use super::*; + + #[test] + fn parse_unconstrained_geometry() { + let spec = parse_postgis_spatial("geometry", PostgisSpatialKind::Geometry); + assert_eq!(spec.subtype, GeometrySubtype::Geometry); + assert_eq!(spec.srid, None); + assert_eq!(spec.spatial, PostgisSpatialKind::Geometry); + } + + #[test] + fn parse_geometry_point_with_srid() { + let spec = parse_postgis_spatial("geometry(POINT,4326)", PostgisSpatialKind::Geometry); + assert_eq!(spec.subtype, GeometrySubtype::Point); + assert_eq!(spec.srid, Some(4326)); + } + + #[test] + fn parse_geometry_multipolygon() { + let spec = parse_postgis_spatial("geometry(MULTIPOLYGON,3857)", PostgisSpatialKind::Geometry); + assert_eq!(spec.subtype, GeometrySubtype::MultiPolygon); + assert_eq!(spec.srid, Some(3857)); + } + + #[test] + fn parse_geography_point_with_srid() { + let spec = parse_postgis_spatial("geography(POINT,4326)", PostgisSpatialKind::Geography); + assert_eq!(spec.subtype, GeometrySubtype::Point); + assert_eq!(spec.srid, Some(4326)); + assert_eq!(spec.spatial, PostgisSpatialKind::Geography); + } + + #[test] + fn parse_geometry_point_subtype_only() { + let spec = parse_postgis_spatial("geometry(POINT)", PostgisSpatialKind::Geometry); + assert_eq!(spec.subtype, GeometrySubtype::Point); + assert_eq!(spec.srid, None); + } +} diff --git a/schema-engine/sql-schema-describer/src/postgres/default.rs b/schema-engine/sql-schema-describer/src/postgres/default.rs index aaa97d1955ca..96f4f3e2f698 100644 --- a/schema-engine/sql-schema-describer/src/postgres/default.rs +++ b/schema-engine/sql-schema-describer/src/postgres/default.rs @@ -102,6 +102,7 @@ fn parser_for_family(family: &ColumnTypeFamily) -> &'static dyn Fn(&mut Parser<' ColumnTypeFamily::DateTime => &parse_datetime_default, ColumnTypeFamily::Binary => &parse_binary_default, ColumnTypeFamily::Udt(_) | ColumnTypeFamily::Unsupported(_) | ColumnTypeFamily::Uuid => &parse_unsupported, + ColumnTypeFamily::Geometry(_) => &parse_unsupported, } } diff --git a/schema-engine/sql-schema-describer/src/postgres/default/c_style_scalar_lists.rs b/schema-engine/sql-schema-describer/src/postgres/default/c_style_scalar_lists.rs index 6029c1e99fb5..164022fb2b4c 100644 --- a/schema-engine/sql-schema-describer/src/postgres/default/c_style_scalar_lists.rs +++ b/schema-engine/sql-schema-describer/src/postgres/default/c_style_scalar_lists.rs @@ -87,6 +87,7 @@ fn parse_literal(s: &str, tpe: &ColumnType) -> Option { | ColumnTypeFamily::Binary | ColumnTypeFamily::Uuid | ColumnTypeFamily::Udt(_) + | ColumnTypeFamily::Geometry(_) | ColumnTypeFamily::Unsupported(_) => None, } } diff --git a/schema-engine/sql-schema-describer/src/sqlite.rs b/schema-engine/sql-schema-describer/src/sqlite.rs index 26982a12bce5..278999435099 100644 --- a/schema-engine/sql-schema-describer/src/sqlite.rs +++ b/schema-engine/sql-schema-describer/src/sqlite.rs @@ -341,7 +341,7 @@ async fn push_columns( ColumnTypeFamily::Binary => DefaultValue::db_generated(default_string), ColumnTypeFamily::Uuid => DefaultValue::db_generated(default_string), ColumnTypeFamily::Enum(_) => DefaultValue::value(PrismaValue::Enum(default_string)), - ColumnTypeFamily::Udt(_) | ColumnTypeFamily::Unsupported(_) => { + ColumnTypeFamily::Udt(_) | ColumnTypeFamily::Unsupported(_) | ColumnTypeFamily::Geometry(_) => { DefaultValue::db_generated(default_string) } }) From 3aa757acc95f50e9eca510aac1bc45e0e8a08247 Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Sat, 21 Mar 2026 16:14:41 +0700 Subject: [PATCH 2/6] feat(db): add Prisma-native querying support for PostGIS --- Cargo.lock | 1 + .../extractors/filters/scalar.rs | 135 ++++++++++++++++++ .../extractors/query_arguments.rs | 65 ++++++++- .../sql-query-builder/src/cursor_condition.rs | 1 + .../sql-query-builder/src/filter/visitor.rs | 94 ++++++++++++ .../sql-query-builder/src/ordering.rs | 38 +++++ .../sql-query-builder/src/select/mod.rs | 7 + query-compiler/query-structure/Cargo.toml | 1 + .../query-structure/src/filter/geometry.rs | 54 +++++++ .../query-structure/src/filter/mod.rs | 11 +- .../query-structure/src/order_by.rs | 54 +++++++ query-compiler/query-structure/src/record.rs | 1 + .../input_types/fields/field_filter_types.rs | 62 +++++++- .../input_types/objects/order_by_objects.rs | 35 +++++ query-compiler/schema/src/constants.rs | 14 ++ query-compiler/schema/src/identifier_type.rs | 8 ++ 16 files changed, 578 insertions(+), 3 deletions(-) create mode 100644 query-compiler/query-structure/src/filter/geometry.rs diff --git a/Cargo.lock b/Cargo.lock index 780a9952b560..9bb3682e63f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3803,6 +3803,7 @@ dependencies = [ "itertools 0.14.0", "prisma-value", "psl", + "serde_json", "thiserror 2.0.17", ] diff --git a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs index 556e060b0872..acf57173d70f 100644 --- a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs +++ b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs @@ -187,6 +187,20 @@ impl<'a> ScalarFilterParser<'a> { aggregations::UNDERSCORE_MIN => self.aggregation_filter(input, Filter::min, false), aggregations::UNDERSCORE_MAX => self.aggregation_filter(input, Filter::max, false), + // Geometry filters + filters::NEAR => { + let input_map: ParsedInputMap<'_> = input.try_into()?; + Ok(vec![parse_geometry_near(field, input_map)?]) + } + filters::WITHIN => { + let input_map: ParsedInputMap<'_> = input.try_into()?; + Ok(vec![parse_geometry_within(field, input_map)?]) + } + filters::INTERSECTS => { + let input_map: ParsedInputMap<'_> = input.try_into()?; + Ok(vec![parse_geometry_intersects(field, input_map)?]) + } + _ => Err(QueryGraphBuilderError::InputError(format!( "{filter_name} is not a valid scalar filter operation" ))), @@ -623,3 +637,124 @@ fn coerce_json_null(value: ConditionValue) -> ConditionValue { _ => value, } } + +fn parse_geometry_near(field: &ScalarFieldRef, mut input_map: ParsedInputMap<'_>) -> QueryGraphBuilderResult { + let point_value = input_map + .swap_remove(filters::POINT) + .ok_or_else(|| QueryGraphBuilderError::InputError("near filter requires 'point' field".to_owned()))?; + let max_distance_value = input_map + .swap_remove(filters::MAX_DISTANCE) + .ok_or_else(|| QueryGraphBuilderError::InputError("near filter requires 'maxDistance' field".to_owned()))?; + let srid_value = input_map.swap_remove(filters::SRID); + + let point_list: Vec = point_value.try_into()?; + if point_list.len() != 2 { + return Err(QueryGraphBuilderError::InputError( + "near filter point must have exactly 2 coordinates".to_owned(), + )); + } + + let lon = extract_float(&point_list[0])?; + let lat = extract_float(&point_list[1])?; + let max_distance = extract_float(&max_distance_value.try_into()?)?; + let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; + + Ok(Filter::Geometry(GeometryFilter { + field: field.clone(), + condition: GeometryFilterCondition::Near { + point: (lon, lat), + max_distance, + srid, + }, + })) +} + +fn parse_geometry_within(field: &ScalarFieldRef, mut input_map: ParsedInputMap<'_>) -> QueryGraphBuilderResult { + let polygon_value = input_map + .swap_remove(filters::POLYGON) + .ok_or_else(|| QueryGraphBuilderError::InputError("within filter requires 'polygon' field".to_owned()))?; + let srid_value = input_map.swap_remove(filters::SRID); + + let polygon_outer: Vec = polygon_value.try_into()?; + let mut polygon = Vec::with_capacity(polygon_outer.len()); + + for coord in polygon_outer { + if let PrismaValue::List(pair) = coord { + if pair.len() != 2 { + return Err(QueryGraphBuilderError::InputError( + "polygon coordinates must be [lon, lat] pairs".to_owned(), + )); + } + let lon = extract_float(&pair[0])?; + let lat = extract_float(&pair[1])?; + polygon.push((lon, lat)); + } else { + return Err(QueryGraphBuilderError::InputError( + "polygon must be an array of coordinate pairs".to_owned(), + )); + } + } + + let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; + + Ok(Filter::Geometry(GeometryFilter { + field: field.clone(), + condition: GeometryFilterCondition::Within { polygon, srid }, + })) +} + +fn parse_geometry_intersects( + field: &ScalarFieldRef, + mut input_map: ParsedInputMap<'_>, +) -> QueryGraphBuilderResult { + let geometry_value = input_map + .swap_remove(filters::GEOMETRY) + .ok_or_else(|| QueryGraphBuilderError::InputError("intersects filter requires 'geometry' field".to_owned()))?; + let srid_value = input_map.swap_remove(filters::SRID); + + let geometry_json: PrismaValue = geometry_value.try_into()?; + let geometry = match geometry_json { + PrismaValue::Json(json_str) => serde_json::from_str(&json_str) + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid GeoJSON: {}", e)))?, + PrismaValue::Object(obj) => serde_json::to_value(obj) + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid GeoJSON object: {}", e)))?, + _ => { + return Err(QueryGraphBuilderError::InputError( + "intersects geometry must be a JSON value".to_owned(), + )) + } + }; + + let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; + + Ok(Filter::Geometry(GeometryFilter { + field: field.clone(), + condition: GeometryFilterCondition::Intersects { geometry, srid }, + })) +} + +fn extract_float(value: &PrismaValue) -> QueryGraphBuilderResult { + match value { + PrismaValue::Int(i) => Ok(*i as f64), + PrismaValue::BigInt(i) => Ok(*i as f64), + PrismaValue::Float(d) => d + .to_string() + .parse::() + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e))), + _ => Err(QueryGraphBuilderError::InputError(format!( + "Expected numeric value, got {:?}", + value + ))), + } +} + +fn extract_int(value: &PrismaValue) -> QueryGraphBuilderResult { + match value { + PrismaValue::Int(i) => Ok(*i as i32), + PrismaValue::BigInt(i) => Ok(*i as i32), + _ => Err(QueryGraphBuilderError::InputError(format!( + "Expected integer value, got {:?}", + value + ))), + } +} diff --git a/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs b/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs index 704d3d8cc0f5..b3108d33e6df 100644 --- a/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs +++ b/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs @@ -138,10 +138,17 @@ fn process_order_object( } Field::Scalar(sf) => { + if matches!(sf.type_identifier(), TypeIdentifier::Geometry(_)) { + if let ParsedInputValue::Map(ref map) = field_value { + if let Some(distance_from_value) = map.get(ordering::DISTANCE_FROM) { + return extract_geometry_distance_from(&sf, distance_from_value.clone(), path); + } + } + } + let (sort_order, nulls_order) = extract_order_by_args(field_value)?; if let Some(sort_aggr) = parent_sort_aggregation { - // If the parent is a sort aggregation then this scalar is part of that one. Ok(Some(OrderBy::scalar_aggregation(sf, sort_order, sort_aggr))) } else { Ok(Some(OrderBy::scalar(sf, path, sort_order, nulls_order))) @@ -362,3 +369,59 @@ fn finalize_arguments(mut args: QueryArguments, model: &Model) -> QueryGraphBuil Ok(args) } + +fn extract_geometry_distance_from( + field: &ScalarFieldRef, + value: ParsedInputValue<'_>, + path: Vec, +) -> QueryGraphBuilderResult> { + let mut object: ParsedInputMap<'_> = value.try_into()?; + + let point_value = object + .swap_remove(schema::constants::filters::POINT) + .ok_or_else(|| QueryGraphBuilderError::InputError("distanceFrom requires 'point' field".to_owned()))?; + let direction_value = object + .swap_remove(ordering::DIRECTION) + .ok_or_else(|| QueryGraphBuilderError::InputError("distanceFrom requires 'direction' field".to_owned()))?; + let srid_value = object.swap_remove(schema::constants::filters::SRID); + + let point_list: Vec = point_value.try_into()?; + if point_list.len() != 2 { + return Err(QueryGraphBuilderError::InputError( + "distanceFrom point must have exactly 2 coordinates".to_owned(), + )); + } + + let lon = extract_float_from_pv(&point_list[0])?; + let lat = extract_float_from_pv(&point_list[1])?; + let sort_order = pv_to_sort_order(direction_value.try_into()?)?; + let srid = srid_value.map(|v| extract_int_from_pv(&v.try_into()?)).transpose()?; + + Ok(Some(OrderBy::geometry(field.clone(), path, (lon, lat), sort_order, srid))) +} + +fn extract_float_from_pv(value: &PrismaValue) -> QueryGraphBuilderResult { + match value { + PrismaValue::Int(i) => Ok(*i as f64), + PrismaValue::BigInt(i) => Ok(*i as f64), + PrismaValue::Float(d) => d + .to_string() + .parse::() + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e))), + _ => Err(QueryGraphBuilderError::InputError(format!( + "Expected numeric value, got {:?}", + value + ))), + } +} + +fn extract_int_from_pv(value: &PrismaValue) -> QueryGraphBuilderResult { + match value { + PrismaValue::Int(i) => Ok(*i as i32), + PrismaValue::BigInt(i) => Ok(*i as i32), + _ => Err(QueryGraphBuilderError::InputError(format!( + "Expected integer value, got {:?}", + value + ))), + } +} diff --git a/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs b/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs index ddaf9d39e71a..505450d763ad 100644 --- a/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs +++ b/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs @@ -442,6 +442,7 @@ fn order_definitions( OrderBy::ScalarAggregation(order_by) => cursor_order_def_aggregation_scalar(order_by, order_by_def), OrderBy::ToManyAggregation(order_by) => cursor_order_def_aggregation_rel(order_by, order_by_def), OrderBy::Relevance(order_by) => cursor_order_def_relevance(order_by, order_by_def), + OrderBy::Geometry(_) => unimplemented!("Cursor-based pagination with geometry orderBy is not yet supported"), }) .collect_vec() } diff --git a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs index 8aa82b778d34..90ed1acae443 100644 --- a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs +++ b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs @@ -21,6 +21,7 @@ pub(crate) trait FilterVisitorExt { ) -> (ConditionTree<'static>, Option>); fn visit_scalar_filter(&mut self, filter: ScalarFilter, ctx: &Context<'_>) -> ConditionTree<'static>; fn visit_scalar_list_filter(&mut self, filter: ScalarListFilter, ctx: &Context<'_>) -> ConditionTree<'static>; + fn visit_geometry_filter(&mut self, filter: GeometryFilter, ctx: &Context<'_>) -> ConditionTree<'static>; fn visit_one_relation_is_null_filter( &mut self, filter: OneRelationIsNullFilter, @@ -315,6 +316,7 @@ impl FilterVisitorExt for FilterVisitor { } }, Filter::Scalar(filter) => (self.visit_scalar_filter(filter, ctx), None), + Filter::Geometry(filter) => (self.visit_geometry_filter(filter, ctx), None), Filter::OneRelationIsNull(filter) => self.visit_one_relation_is_null_filter(filter, ctx), Filter::Relation(filter) => self.visit_relation_filter(filter, ctx), Filter::BoolFilter(b) => { @@ -614,6 +616,89 @@ impl FilterVisitorExt for FilterVisitor { ConditionTree::single(condition) } + + fn visit_geometry_filter(&mut self, filter: GeometryFilter, ctx: &Context<'_>) -> ConditionTree<'static> { + let field_column = filter.field.as_column(ctx); + let field_ref = format!("\"{}\"", field_column.name); + + let srid = match &filter.condition { + GeometryFilterCondition::Near { srid, .. } + | GeometryFilterCondition::Within { srid, .. } + | GeometryFilterCondition::Intersects { srid, .. } => srid.unwrap_or(4326), + }; + + let use_geography = srid == 4326 || srid == 4269 || srid == 4167; + + let sql = match filter.condition { + GeometryFilterCondition::Near { + point, + max_distance, + .. + } => { + let (lon, lat) = point; + if use_geography { + format!( + "ST_DWithin({}::geography, ST_SetSRID(ST_MakePoint({}, {}), {})::geography, {})", + field_ref, lon, lat, srid, max_distance + ) + } else { + format!( + "ST_DWithin({}, ST_SetSRID(ST_MakePoint({}, {}), {}), {})", + field_ref, lon, lat, srid, max_distance + ) + } + } + GeometryFilterCondition::Within { polygon, .. } => { + let wkt = format_polygon_wkt(&polygon); + let escaped_wkt = wkt.replace('\'', "''"); + format!( + "ST_Within({}, ST_GeomFromText('{}', {}))", + field_ref, escaped_wkt, srid + ) + } + GeometryFilterCondition::Intersects { geometry, .. } => { + let geom_type = geometry.get("type").and_then(|v| v.as_str()).unwrap_or(""); + let wkt = match geom_type { + "Polygon" => { + if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { + let ring_strs: Vec = coords.iter() + .filter_map(|ring| { + ring.as_array().map(|points| { + let point_strs: Vec = points.iter() + .filter_map(|p| { + p.as_array().and_then(|arr| { + if arr.len() >= 2 { + Some(format!("{} {}", + arr[0].as_f64().unwrap_or(0.0), + arr[1].as_f64().unwrap_or(0.0))) + } else { + None + } + }) + }) + .collect(); + format!("({})", point_strs.join(", ")) + }) + }) + .collect(); + format!("POLYGON({})", ring_strs.join(", ")) + } else { + "POLYGON EMPTY".to_string() + } + } + _ => format!("POINT(0 0)"), + }; + let escaped_wkt = wkt.replace('\'', "''"); + format!( + "ST_Intersects({}, ST_GeomFromText('{}', {}))", + field_ref, escaped_wkt, srid + ) + } + }; + + let raw_expr: Expression = Value::enum_variant(sql).raw().into(); + ConditionTree::single(raw_expr) + } } fn scalar_filter_aliased_cond( @@ -1536,3 +1621,12 @@ impl JsonFilterExt for (Expression<'static>, Expression<'static>) { } } } + +fn format_polygon_wkt(polygon: &[(f64, f64)]) -> String { + let coords = polygon + .iter() + .map(|(x, y)| format!("{} {}", x, y)) + .collect::>() + .join(", "); + format!("POLYGON(({}))", coords) +} diff --git a/query-compiler/query-builders/sql-query-builder/src/ordering.rs b/query-compiler/query-builders/sql-query-builder/src/ordering.rs index a5040f41b9ee..1a22d39149ba 100644 --- a/query-compiler/query-builders/sql-query-builder/src/ordering.rs +++ b/query-compiler/query-builders/sql-query-builder/src/ordering.rs @@ -54,6 +54,7 @@ impl OrderByBuilder { reachable_only_with_capability!(ConnectorCapability::NativeFullTextSearch); self.build_order_relevance(order_by, needs_reversed_order, ctx) } + OrderBy::Geometry(order_by) => self.build_order_geometry(order_by, needs_reversed_order, ctx), }) .collect_vec() } @@ -279,6 +280,43 @@ impl OrderByBuilder { format!("{}{}", ORDER_JOIN_PREFIX, self.join_counter) } + + fn build_order_geometry( + &mut self, + order_by: &OrderByGeometry, + needs_reversed_order: bool, + ctx: &Context<'_>, + ) -> OrderByDefinition { + let parent_alias = self.parent_alias.clone(); + let joins: Vec = self.compute_one2m_join(&order_by.path, parent_alias.as_ref(), ctx); + + let parent_table = joins + .last() + .map(|j| j.alias.to_owned()) + .or_else(|| self.parent_alias.clone()); + let field_column = order_by.field.as_column(ctx).opt_table(parent_table); + + let (lon, lat) = order_by.point; + let srid = order_by.srid.unwrap_or(4326); + + let field_ref = format!("\"{}\"", field_column.name); + + let sql = format!( + "ST_Distance(CAST({} AS geography), CAST(ST_SetSRID(ST_MakePoint({}, {}), {}) AS geography))", + field_ref, lon, lat, srid + ); + + let distance_expr: Expression = Value::enum_variant(sql).raw().into(); + + let order = Some(into_order(&order_by.sort_order, None, needs_reversed_order)); + let order_definition: OrderDefinition = (distance_expr.clone(), order); + + OrderByDefinition { + order_column: distance_expr.clone(), + order_definition, + joins, + } + } } fn prisma_value_to_search_expression(pv: PrismaValue) -> Expression<'static> { diff --git a/query-compiler/query-builders/sql-query-builder/src/select/mod.rs b/query-compiler/query-builders/sql-query-builder/src/select/mod.rs index c505d8a95ac5..9a9879fda5da 100644 --- a/query-compiler/query-builders/sql-query-builder/src/select/mod.rs +++ b/query-compiler/query-builders/sql-query-builder/src/select/mod.rs @@ -589,6 +589,13 @@ fn order_by_selection(rs: &RelationSelection) -> FieldSelection { // This is necessary because the order by is done on a different join. The following hops are handled by the order by builder. OrderBy::ToManyAggregation(x) => first_hop_linking_fields(x.intermediary_hops()), OrderBy::ScalarAggregation(x) => vec![x.field.clone()], + OrderBy::Geometry(x) => { + if x.path.is_empty() { + vec![x.field.clone()] + } else { + first_hop_linking_fields(&x.path) + } + } }) .collect(); diff --git a/query-compiler/query-structure/Cargo.toml b/query-compiler/query-structure/Cargo.toml index 6cd11bb6dea6..ede4ff084d5d 100644 --- a/query-compiler/query-structure/Cargo.toml +++ b/query-compiler/query-structure/Cargo.toml @@ -9,6 +9,7 @@ itertools.workspace = true prisma-value.workspace = true bigdecimal.workspace = true thiserror.workspace = true +serde_json.workspace = true chrono.workspace = true indexmap.workspace = true diff --git a/query-compiler/query-structure/src/filter/geometry.rs b/query-compiler/query-structure/src/filter/geometry.rs new file mode 100644 index 000000000000..dffde355cdf8 --- /dev/null +++ b/query-compiler/query-structure/src/filter/geometry.rs @@ -0,0 +1,54 @@ +use crate::*; + +#[derive(Debug, Clone, PartialEq)] +pub struct GeometryFilter { + pub field: ScalarFieldRef, + pub condition: GeometryFilterCondition, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum GeometryFilterCondition { + Near { + point: (f64, f64), + max_distance: f64, + srid: Option, + }, + Within { + polygon: Vec<(f64, f64)>, + srid: Option, + }, + Intersects { + geometry: serde_json::Value, + srid: Option, + }, +} + +impl std::hash::Hash for GeometryFilter { + fn hash(&self, state: &mut H) { + self.field.hash(state); + match &self.condition { + GeometryFilterCondition::Near { point, max_distance, srid } => { + "Near".hash(state); + point.0.to_bits().hash(state); + point.1.to_bits().hash(state); + max_distance.to_bits().hash(state); + srid.hash(state); + } + GeometryFilterCondition::Within { polygon, srid } => { + "Within".hash(state); + for (x, y) in polygon { + x.to_bits().hash(state); + y.to_bits().hash(state); + } + srid.hash(state); + } + GeometryFilterCondition::Intersects { geometry, srid } => { + "Intersects".hash(state); + geometry.to_string().hash(state); + srid.hash(state); + } + } + } +} + +impl Eq for GeometryFilter {} diff --git a/query-compiler/query-structure/src/filter/mod.rs b/query-compiler/query-structure/src/filter/mod.rs index af351ae0d2bb..187a921cf69b 100644 --- a/query-compiler/query-structure/src/filter/mod.rs +++ b/query-compiler/query-structure/src/filter/mod.rs @@ -7,6 +7,7 @@ mod compare; mod composite; +mod geometry; mod into_filter; mod json; mod list; @@ -16,6 +17,7 @@ mod scalar; pub use compare::*; pub use composite::*; +pub use geometry::*; pub use into_filter::*; pub use json::*; pub use list::*; @@ -35,6 +37,7 @@ pub enum Filter { OneRelationIsNull(OneRelationIsNullFilter), Relation(RelationFilter), Composite(CompositeFilter), + Geometry(GeometryFilter), BoolFilter(bool), Aggregation(AggregationFilter), Empty, @@ -226,7 +229,7 @@ impl Filter { use Filter::*; match self { Not(branches) | Or(branches) | And(branches) => branches.iter().any(|filter| filter.has_relations()), - Scalar(..) | ScalarList(..) | Composite(..) | BoolFilter(..) | Empty => false, + Scalar(..) | ScalarList(..) | Composite(..) | Geometry(..) | BoolFilter(..) | Empty => false, Aggregation(filter) => match filter { Average(filter) | Count(filter) | Sum(filter) | Min(filter) | Max(filter) => filter.has_relations(), }, @@ -290,3 +293,9 @@ impl From for Filter { Filter::Composite(cf) } } + +impl From for Filter { + fn from(gf: GeometryFilter) -> Self { + Filter::Geometry(gf) + } +} diff --git a/query-compiler/query-structure/src/order_by.rs b/query-compiler/query-structure/src/order_by.rs index c9c4a2b96d67..5994599f377c 100644 --- a/query-compiler/query-structure/src/order_by.rs +++ b/query-compiler/query-structure/src/order_by.rs @@ -29,6 +29,7 @@ pub enum OrderBy { ScalarAggregation(OrderByScalarAggregation), ToManyAggregation(OrderByToManyAggregation), Relevance(OrderByRelevance), + Geometry(OrderByGeometry), } impl OrderBy { @@ -38,6 +39,7 @@ impl OrderBy { OrderBy::ToManyAggregation(o) => Some(&o.path), OrderBy::ScalarAggregation(_) => None, OrderBy::Relevance(_) => None, + OrderBy::Geometry(o) => Some(&o.path), } } @@ -47,6 +49,7 @@ impl OrderBy { OrderBy::ScalarAggregation(o) => o.sort_order, OrderBy::ToManyAggregation(o) => o.sort_order, OrderBy::Relevance(o) => o.sort_order, + OrderBy::Geometry(o) => o.sort_order, } } @@ -56,6 +59,7 @@ impl OrderBy { OrderBy::ScalarAggregation(o) => Some(o.field.clone()), OrderBy::ToManyAggregation(_) => None, OrderBy::Relevance(_) => None, + OrderBy::Geometry(o) => Some(o.field.clone()), } } @@ -113,6 +117,22 @@ impl OrderBy { path, }) } + + pub fn geometry( + field: ScalarFieldRef, + path: Vec, + point: (f64, f64), + sort_order: SortOrder, + srid: Option, + ) -> Self { + Self::Geometry(OrderByGeometry { + field, + path, + point, + sort_order, + srid, + }) + } } /// Describes a hop over to a relation or composite for an orderBy statement. @@ -211,6 +231,40 @@ pub struct OrderByRelevance { pub path: Vec, } +#[derive(Clone, PartialEq)] +pub struct OrderByGeometry { + pub field: ScalarFieldRef, + pub path: Vec, + pub point: (f64, f64), + pub sort_order: SortOrder, + pub srid: Option, +} + +impl std::fmt::Debug for OrderByGeometry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OrderByGeometry") + .field("field", &format!("{}", self.field)) + .field("path", &self.path) + .field("point", &self.point) + .field("sort_order", &self.sort_order) + .field("srid", &self.srid) + .finish() + } +} + +impl std::hash::Hash for OrderByGeometry { + fn hash(&self, state: &mut H) { + self.field.hash(state); + self.path.hash(state); + self.point.0.to_bits().hash(state); + self.point.1.to_bits().hash(state); + self.sort_order.hash(state); + self.srid.hash(state); + } +} + +impl Eq for OrderByGeometry {} + impl Display for SortOrder { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/query-compiler/query-structure/src/record.rs b/query-compiler/query-structure/src/record.rs index cfa527d9f9bf..b88a9769d924 100644 --- a/query-compiler/query-structure/src/record.rs +++ b/query-compiler/query-structure/src/record.rs @@ -80,6 +80,7 @@ impl ManyRecords { OrderBy::ScalarAggregation(_) => unimplemented!(), OrderBy::ToManyAggregation(_) => unimplemented!(), OrderBy::Relevance(_) => unimplemented!(), + OrderBy::Geometry(_) => unimplemented!(), }); orderings diff --git a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs index 86af2cdd5983..48d099ec4a39 100644 --- a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs +++ b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs @@ -279,7 +279,10 @@ fn full_scalar_filter_type( TypeIdentifier::Boolean => equality_filters(mapped_scalar_type.clone(), nullable).collect(), - TypeIdentifier::Geometry(_) => equality_filters(mapped_scalar_type.clone(), nullable).collect(), + TypeIdentifier::Geometry(_) => equality_filters(mapped_scalar_type.clone(), nullable) + .chain(inclusion_filters(ctx, mapped_scalar_type.clone(), nullable)) + .chain(geometry_filters()) + .collect(), TypeIdentifier::Bytes | TypeIdentifier::Enum(_) => equality_filters(mapped_scalar_type.clone(), nullable) .chain(inclusion_filters(ctx, mapped_scalar_type.clone(), nullable)) @@ -619,3 +622,60 @@ fn not_filter_field<'a>( } } } + +fn geometry_near_input<'a>() -> InputObjectType<'a> { + let ident = Identifier::new_prisma(IdentifierType::GeometryNearInput); + let mut object = init_input_object_type(ident); + + object.set_fields(|| { + vec![ + simple_input_field(filters::POINT, InputType::list(InputType::float()), None).required(), + simple_input_field(filters::MAX_DISTANCE, InputType::float(), None).required(), + simple_input_field(filters::SRID, InputType::int(), None).optional(), + ] + }); + + object +} + +fn geometry_within_input<'a>() -> InputObjectType<'a> { + let ident = Identifier::new_prisma(IdentifierType::GeometryWithinInput); + let mut object = init_input_object_type(ident); + + object.set_fields(|| { + vec![ + simple_input_field( + filters::POLYGON, + InputType::list(InputType::list(InputType::float())), + None, + ) + .required(), + simple_input_field(filters::SRID, InputType::int(), None).optional(), + ] + }); + + object +} + +fn geometry_intersects_input<'a>() -> InputObjectType<'a> { + let ident = Identifier::new_prisma(IdentifierType::GeometryIntersectsInput); + let mut object = init_input_object_type(ident); + + object.set_fields(|| { + vec![ + simple_input_field(filters::GEOMETRY, InputType::json(), None).required(), + simple_input_field(filters::SRID, InputType::int(), None).optional(), + ] + }); + + object +} + +fn geometry_filters<'a>() -> impl Iterator> { + vec![ + simple_input_field(filters::NEAR, InputType::object(geometry_near_input()), None).optional(), + simple_input_field(filters::WITHIN, InputType::object(geometry_within_input()), None).optional(), + simple_input_field(filters::INTERSECTS, InputType::object(geometry_intersects_input()), None).optional(), + ] + .into_iter() +} diff --git a/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs b/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs index a668441b536c..287d8c5109e7 100644 --- a/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs +++ b/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs @@ -136,6 +136,10 @@ fn orderby_field_mapper<'a>( types.push(InputType::object(sort_nulls_object_type())); } + if matches!(sf.type_identifier(), TypeIdentifier::Geometry(_)) { + types.push(InputType::object(geometry_order_by_object_type())); + } + Some(input_field(sf.name().to_owned(), types, None).optional()) } @@ -273,3 +277,34 @@ fn order_by_object_type_text_search<'a>( }); input_object } + +fn geometry_distance_from_input<'a>() -> InputObjectType<'a> { + let ident = Identifier::new_prisma(IdentifierType::GeometryDistanceFromInput); + let mut object = init_input_object_type(ident); + + object.set_fields(|| { + vec![ + simple_input_field(constants::filters::POINT, InputType::list(InputType::float()), None).required(), + simple_input_field(ordering::DIRECTION, InputType::Enum(sort_order_enum()), None).required(), + simple_input_field(constants::filters::SRID, InputType::int(), None).optional(), + ] + }); + + object +} + +fn geometry_order_by_object_type<'a>() -> InputObjectType<'a> { + let ident = Identifier::new_prisma("GeometryOrderByInput"); + let mut object = init_input_object_type(ident); + + object.set_fields(|| { + vec![simple_input_field( + ordering::DISTANCE_FROM, + InputType::object(geometry_distance_from_input()), + None, + ) + .optional()] + }); + + object +} diff --git a/query-compiler/schema/src/constants.rs b/query-compiler/schema/src/constants.rs index 47adbb822351..d0b44c9f3995 100644 --- a/query-compiler/schema/src/constants.rs +++ b/query-compiler/schema/src/constants.rs @@ -110,6 +110,16 @@ pub mod filters { pub const STRING_STARTS_WITH: &str = "string_starts_with"; pub const STRING_ENDS_WITH: &str = "string_ends_with"; pub const JSON_TYPE: &str = "json_type"; + + // geometry filters + pub const NEAR: &str = "near"; + pub const WITHIN: &str = "within"; + pub const INTERSECTS: &str = "intersects"; + pub const POINT: &str = "point"; + pub const MAX_DISTANCE: &str = "maxDistance"; + pub const POLYGON: &str = "polygon"; + pub const GEOMETRY: &str = "geometry"; + pub const SRID: &str = "srid"; } pub mod aggregations { @@ -140,6 +150,10 @@ pub mod ordering { pub const SORT: &str = "sort"; pub const NULLS: &str = "nulls"; pub const FIELDS: &str = "fields"; + + // geometry ordering + pub const DISTANCE_FROM: &str = "distanceFrom"; + pub const DIRECTION: &str = "direction"; } pub mod json_null { diff --git a/query-compiler/schema/src/identifier_type.rs b/query-compiler/schema/src/identifier_type.rs index d05e6002b1c1..6f7af4fcf4b9 100644 --- a/query-compiler/schema/src/identifier_type.rs +++ b/query-compiler/schema/src/identifier_type.rs @@ -54,6 +54,10 @@ pub enum IdentifierType { UpdateManyAndReturnOutput(Model), WhereInput(ParentContainer), WhereUniqueInput(Model), + GeometryNearInput, + GeometryWithinInput, + GeometryIntersectsInput, + GeometryDistanceFromInput, Raw(String), } @@ -315,6 +319,10 @@ impl std::fmt::Display for IdentifierType { IdentifierType::UpdateManyAndReturnOutput(model) => { write!(f, "UpdateMany{}AndReturnOutputType", model.name()) } + IdentifierType::GeometryNearInput => f.write_str("GeometryNearInput"), + IdentifierType::GeometryWithinInput => f.write_str("GeometryWithinInput"), + IdentifierType::GeometryIntersectsInput => f.write_str("GeometryIntersectsInput"), + IdentifierType::GeometryDistanceFromInput => f.write_str("GeometryDistanceFromInput"), } } } From f4695b4e409163291b9767fac6e8f12d7cf06f59 Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Sat, 21 Mar 2026 16:27:44 +0700 Subject: [PATCH 3/6] feat(db): enhance PostGIS querying with Prisma-native patterns and update tests --- .../tests/geometry-filters-graph-builds.rs | 411 ++++++++++++++++++ .../geometry-combined-scalar-spatial.json | 30 ++ .../data/geometry-count-with-filter.json | 26 ++ .../data/geometry-delete-with-filter.json | 19 + .../data/geometry-filter-and-orderby.json | 30 ++ .../geometry-filter-and-scalar-filter.json | 29 ++ .../data/geometry-filter-custom-srid.json | 21 + .../data/geometry-filter-intersects.json.skip | 30 ++ .../tests/data/geometry-filter-near.json | 20 + .../tests/data/geometry-filter-not-near.json | 22 + .../data/geometry-filter-or-multiple.json | 32 ++ .../tests/data/geometry-filter-within.json | 25 ++ .../tests/data/geometry-multiple-orderby.json | 25 ++ .../data/geometry-orderby-distance-asc.json | 22 + .../data/geometry-orderby-distance-desc.json | 22 + .../data/geometry-orderby-with-limit.json | 23 + .../query-compiler/tests/data/schema.prisma | 3 +- ...geometry-combined-scalar-spatial.json.snap | 16 + ...eries@geometry-count-with-filter.json.snap | 15 + ...ries@geometry-delete-with-filter.json.snap | 10 + ...ries@geometry-filter-and-orderby.json.snap | 15 + ...eometry-filter-and-scalar-filter.json.snap | 14 + ...ries@geometry-filter-custom-srid.json.snap | 14 + ...es__queries@geometry-filter-near.json.snap | 14 + ...queries@geometry-filter-not-near.json.snap | 14 + ...ries@geometry-filter-or-multiple.json.snap | 15 + ...__queries@geometry-filter-within.json.snap | 14 + ...ries__queries@geometry-find-many.json.snap | 3 +- ...ueries@geometry-multiple-orderby.json.snap | 15 + ...es@geometry-orderby-distance-asc.json.snap | 14 + ...s@geometry-orderby-distance-desc.json.snap | 14 + ...ries@geometry-orderby-with-limit.json.snap | 14 + 32 files changed, 989 insertions(+), 2 deletions(-) create mode 100644 query-compiler/core-tests/tests/geometry-filters-graph-builds.rs create mode 100644 query-compiler/query-compiler/tests/data/geometry-combined-scalar-spatial.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-count-with-filter.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-delete-with-filter.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-and-orderby.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-and-scalar-filter.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-custom-srid.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-intersects.json.skip create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-near.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-not-near.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-or-multiple.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-filter-within.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-multiple-orderby.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-orderby-distance-asc.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-orderby-distance-desc.json create mode 100644 query-compiler/query-compiler/tests/data/geometry-orderby-with-limit.json create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap create mode 100644 query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap diff --git a/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs b/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs new file mode 100644 index 000000000000..ace8692513e4 --- /dev/null +++ b/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs @@ -0,0 +1,411 @@ +use std::sync::Arc; + +use query_core::{QueryDocument, QueryGraphBuilder}; +use request_handlers::{JsonBody, JsonSingleQuery, RequestBody}; + +#[test] +fn geometry_near_filter_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + generator client { + provider = "prisma-client" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [2.35, 48.85], + "maxDistance": 100000 + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with near filter should compile to a query graph"); +} + +#[test] +fn geometry_within_filter_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "within": { + "polygon": [ + [0, 0], + [0, 2], + [2, 2], + [2, 0], + [0, 0] + ] + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with within filter should compile to a query graph"); +} + +#[test] +fn geometry_orderby_distance_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "asc" + } + } + } + ] + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with distanceFrom orderBy should compile to a query graph"); +} + +#[test] +fn geometry_combined_filter_and_orderby_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 50000 + } + } + }, + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "asc" + } + } + } + ] + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with geometry filter and orderBy should compile to a query graph"); +} + +#[test] +fn geometry_not_filter_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "NOT": { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 10000 + } + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with NOT geometry filter should compile to a query graph"); +} + +#[test] +fn geometry_or_filter_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry(Point, 4326)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "OR": [ + { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 10000 + } + } + }, + { + "position": { + "near": { + "point": [10, 10], + "maxDistance": 5000 + } + } + } + ] + } + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with OR geometry filter should compile to a query graph"); +} + +#[test] +fn geometry_custom_srid_builds_query_graph() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model LocationMercator { + id Int @id @default(autoincrement()) + position Geometry(Point, 3857)? + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "LocationMercator", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [1000000, 6000000], + "maxDistance": 5000, + "srid": 3857 + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("findMany with custom SRID 3857 should compile to a query graph"); +} diff --git a/query-compiler/query-compiler/tests/data/geometry-combined-scalar-spatial.json b/query-compiler/query-compiler/tests/data/geometry-combined-scalar-spatial.json new file mode 100644 index 000000000000..cc952b514929 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-combined-scalar-spatial.json @@ -0,0 +1,30 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "AND": [ + { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 10000 + } + } + }, + { + "name": { + "startsWith": "Paris" + } + } + ] + } + }, + "selection": { + "id": true, + "name": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-count-with-filter.json b/query-compiler/query-compiler/tests/data/geometry-count-with-filter.json new file mode 100644 index 000000000000..f0252bb38180 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-count-with-filter.json @@ -0,0 +1,26 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "within": { + "polygon": [ + [-1, -1], + [-1, 5], + [5, 5], + [5, -1], + [-1, -1] + ] + } + } + }, + "take": 10 + }, + "selection": { + "$composites": true, + "$scalars": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-delete-with-filter.json b/query-compiler/query-compiler/tests/data/geometry-delete-with-filter.json new file mode 100644 index 000000000000..63f51fc11fc0 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-delete-with-filter.json @@ -0,0 +1,19 @@ +{ + "modelName": "Location", + "action": "deleteMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 1000 + } + } + } + }, + "selection": { + "count": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-and-orderby.json b/query-compiler/query-compiler/tests/data/geometry-filter-and-orderby.json new file mode 100644 index 000000000000..6af734618f65 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-and-orderby.json @@ -0,0 +1,30 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 500000 + } + } + }, + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "asc" + } + } + } + ] + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-and-scalar-filter.json b/query-compiler/query-compiler/tests/data/geometry-filter-and-scalar-filter.json new file mode 100644 index 000000000000..4084786b154d --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-and-scalar-filter.json @@ -0,0 +1,29 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "AND": [ + { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 50000 + } + } + }, + { + "id": { + "gt": 100 + } + } + ] + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-custom-srid.json b/query-compiler/query-compiler/tests/data/geometry-filter-custom-srid.json new file mode 100644 index 000000000000..905ceabbf96f --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-custom-srid.json @@ -0,0 +1,21 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [1000000, 6000000], + "maxDistance": 5000, + "srid": 3857 + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-intersects.json.skip b/query-compiler/query-compiler/tests/data/geometry-filter-intersects.json.skip new file mode 100644 index 000000000000..2ad852840a36 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-intersects.json.skip @@ -0,0 +1,30 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "intersects": { + "geometry": { + "type": "Polygon", + "coordinates": [ + [ + [1, 1], + [1, 3], + [3, 3], + [3, 1], + [1, 1] + ] + ] + } + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-near.json b/query-compiler/query-compiler/tests/data/geometry-filter-near.json new file mode 100644 index 000000000000..40e3b417dc89 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-near.json @@ -0,0 +1,20 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "near": { + "point": [2.35, 48.85], + "maxDistance": 100000 + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-not-near.json b/query-compiler/query-compiler/tests/data/geometry-filter-not-near.json new file mode 100644 index 000000000000..cd6aba999fdc --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-not-near.json @@ -0,0 +1,22 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "NOT": { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 10000 + } + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-or-multiple.json b/query-compiler/query-compiler/tests/data/geometry-filter-or-multiple.json new file mode 100644 index 000000000000..a1b5f31382c2 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-or-multiple.json @@ -0,0 +1,32 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "OR": [ + { + "position": { + "near": { + "point": [0, 0], + "maxDistance": 10000 + } + } + }, + { + "position": { + "near": { + "point": [10, 10], + "maxDistance": 5000 + } + } + } + ] + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-filter-within.json b/query-compiler/query-compiler/tests/data/geometry-filter-within.json new file mode 100644 index 000000000000..f16d72459dcc --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-filter-within.json @@ -0,0 +1,25 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "within": { + "polygon": [ + [0, 0], + [0, 2], + [2, 2], + [2, 0], + [0, 0] + ] + } + } + } + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-multiple-orderby.json b/query-compiler/query-compiler/tests/data/geometry-multiple-orderby.json new file mode 100644 index 000000000000..78b12a1112c1 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-multiple-orderby.json @@ -0,0 +1,25 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "asc" + } + } + }, + { + "id": "desc" + } + ] + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-orderby-distance-asc.json b/query-compiler/query-compiler/tests/data/geometry-orderby-distance-asc.json new file mode 100644 index 000000000000..447381b486d3 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-orderby-distance-asc.json @@ -0,0 +1,22 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [2.35, 48.85], + "direction": "asc" + } + } + } + ] + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-orderby-distance-desc.json b/query-compiler/query-compiler/tests/data/geometry-orderby-distance-desc.json new file mode 100644 index 000000000000..c4f622158c64 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-orderby-distance-desc.json @@ -0,0 +1,22 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "desc" + } + } + } + ] + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/geometry-orderby-with-limit.json b/query-compiler/query-compiler/tests/data/geometry-orderby-with-limit.json new file mode 100644 index 000000000000..516db8f1e054 --- /dev/null +++ b/query-compiler/query-compiler/tests/data/geometry-orderby-with-limit.json @@ -0,0 +1,23 @@ +{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "orderBy": [ + { + "position": { + "distanceFrom": { + "point": [0, 0], + "direction": "asc" + } + } + } + ], + "take": 5 + }, + "selection": { + "id": true, + "position": true + } + } +} diff --git a/query-compiler/query-compiler/tests/data/schema.prisma b/query-compiler/query-compiler/tests/data/schema.prisma index 310f7af17d45..1f02dea8f358 100644 --- a/query-compiler/query-compiler/tests/data/schema.prisma +++ b/query-compiler/query-compiler/tests/data/schema.prisma @@ -87,7 +87,8 @@ model DataTypes { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326) + name String? + position Geometry(Point, 4326)? } model Patient { diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap new file mode 100644 index 000000000000..eaad457ca5fe --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap @@ -0,0 +1,16 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-combined-scalar-spatial.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + name: String? (name) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."name", "t0"."position" FROM "public"."Location" + AS "t0" WHERE (ST_DWithin("position"::geography, + ST_SetSRID(ST_MakePoint(0, 0), 4326)::geography, 10000) AND + "t0"."name"::text LIKE $1)» +params [const(String("Paris%"))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap new file mode 100644 index 000000000000..a8e73c6547b6 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap @@ -0,0 +1,15 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-count-with-filter.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + name: String? (name) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."name", "t0"."position" FROM "public"."Location" + AS "t0" WHERE ST_Within("position", ST_GeomFromText('POLYGON((-1 -1, -1 + 5, 5 5, 5 -1, -1 -1))', 4326)) ORDER BY "t0"."id" ASC LIMIT $1» +params [const(BigInt(10))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap new file mode 100644 index 000000000000..98d1d2c38e1d --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap @@ -0,0 +1,10 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-delete-with-filter.json +snapshot_kind: text +--- +dataMap affectedRows +execute «DELETE FROM "public"."Location" WHERE ST_DWithin("position"::geography, + ST_SetSRID(ST_MakePoint(0, 0), 4326)::geography, 1000)» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap new file mode 100644 index 000000000000..f6098f4c10c2 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap @@ -0,0 +1,15 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-and-orderby.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), + 4326)::geography, 500000) ORDER BY ST_Distance(CAST("position" AS + geography), CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap new file mode 100644 index 000000000000..23474d954577 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-and-scalar-filter.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + (ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), + 4326)::geography, 50000) AND "t0"."id" > $1)» +params [const(BigInt(100))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap new file mode 100644 index 000000000000..5bfba168fd12 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-custom-srid.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + ST_DWithin("position", ST_SetSRID(ST_MakePoint(1000000, 6000000), 3857), + 5000)» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap new file mode 100644 index 000000000000..1a745b0ff6e5 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-near.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(2.35, 48.85), + 4326)::geography, 100000)» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap new file mode 100644 index 000000000000..7a87b611fb3b --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-not-near.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + (NOT ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), + 4326)::geography, 10000))» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap new file mode 100644 index 000000000000..b7100e40150e --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap @@ -0,0 +1,15 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-or-multiple.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + (ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), + 4326)::geography, 10000) OR ST_DWithin("position"::geography, + ST_SetSRID(ST_MakePoint(10, 10), 4326)::geography, 5000))» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap new file mode 100644 index 000000000000..3b8f9993e9bc --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-filter-within.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE + ST_Within("position", ST_GeomFromText('POLYGON((0 0, 0 2, 2 2, 2 0, 0 + 0))', 4326))» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap index 0d804411f0e1..b15ef5ea3866 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-find-many.json.snap @@ -2,10 +2,11 @@ source: query-compiler/query-compiler/tests/queries.rs expression: pretty input_file: query-compiler/query-compiler/tests/data/geometry-find-many.json +snapshot_kind: text --- dataMap { id: Int (id) - position: Geometry(point) (position) + position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0"» params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap new file mode 100644 index 000000000000..adf964059a5b --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap @@ -0,0 +1,15 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-multiple-orderby.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER + BY ST_Distance(CAST("position" AS geography), + CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC, "t0"."id" + DESC» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap new file mode 100644 index 000000000000..f4d717b1a431 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-orderby-distance-asc.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER + BY ST_Distance(CAST("position" AS geography), + CAST(ST_SetSRID(ST_MakePoint(2.35, 48.85), 4326) AS geography)) ASC» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap new file mode 100644 index 000000000000..c5048c5579f9 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-orderby-distance-desc.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER + BY ST_Distance(CAST("position" AS geography), + CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) DESC» +params [] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap new file mode 100644 index 000000000000..77be8cca7955 --- /dev/null +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap @@ -0,0 +1,14 @@ +--- +source: query-compiler/query-compiler/tests/queries.rs +expression: pretty +input_file: query-compiler/query-compiler/tests/data/geometry-orderby-with-limit.json +snapshot_kind: text +--- +dataMap { + id: Int (id) + position: Geometry(point)? (position) +} +query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER + BY ST_Distance(CAST("position" AS geography), + CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC LIMIT $1» +params [const(BigInt(5))] From b6441178f7450e9b536e6ac7ee5f1137e8d4c25f Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Sat, 21 Mar 2026 16:50:44 +0700 Subject: [PATCH 4/6] fix(postgis): address SRID None vs 0 edge cases and update operation conflicts --- psl/parser-database/src/types.rs | 6 ++- psl/schema-ast/src/parser/parse_types.rs | 8 ++- .../extractors/filters/scalar.rs | 15 ++++++ .../sql-query-builder/src/filter/visitor.rs | 36 +++++++++++++- .../query-structure/src/filter/geometry.rs | 49 ++++++++++++++++++- .../query-structure/src/order_by.rs | 13 ++++- .../fields/data_input_mapper/update.rs | 2 +- 7 files changed, 121 insertions(+), 8 deletions(-) diff --git a/psl/parser-database/src/types.rs b/psl/parser-database/src/types.rs index 48a65b10b1d0..17e48ba0b55e 100644 --- a/psl/parser-database/src/types.rs +++ b/psl/parser-database/src/types.rs @@ -270,8 +270,10 @@ impl GeometrySpec { PostgisSpatialKind::Geography => "geography", }; let subtype = self.subtype.as_str(); - let srid = self.srid.unwrap_or(0); - format!("{base}({subtype},{srid})") + match self.srid { + Some(srid) => format!("{base}({subtype},{srid})"), + None => format!("{base}({subtype})"), + } } } diff --git a/psl/schema-ast/src/parser/parse_types.rs b/psl/schema-ast/src/parser/parse_types.rs index 5fe91559a829..5d6de2393dd3 100644 --- a/psl/schema-ast/src/parser/parse_types.rs +++ b/psl/schema-ast/src/parser/parse_types.rs @@ -76,7 +76,13 @@ fn parse_geometry_type(pair: Pair<'_>, file_id: FileId) -> Result() { - Ok(v) => Some(v), + Ok(v) if v >= 0 && v <= 999_999 => Some(v), + Ok(v) => { + return Err(DatamodelError::new_validation_error( + &format!("Invalid SRID: expected a value between 0 and 999999, got {}.", v), + (file_id, srid_pair.as_span()).into(), + )); + } Err(_) => { return Err(DatamodelError::new_validation_error( "Invalid SRID: expected a valid 32-bit integer.", diff --git a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs index acf57173d70f..5b9b48d2c9a7 100644 --- a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs +++ b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs @@ -189,14 +189,29 @@ impl<'a> ScalarFilterParser<'a> { // Geometry filters filters::NEAR => { + if self.reverse() { + return Err(QueryGraphBuilderError::InputError( + "Negation (NOT) is not supported for geometry 'near' filters".to_string(), + )); + } let input_map: ParsedInputMap<'_> = input.try_into()?; Ok(vec![parse_geometry_near(field, input_map)?]) } filters::WITHIN => { + if self.reverse() { + return Err(QueryGraphBuilderError::InputError( + "Negation (NOT) is not supported for geometry 'within' filters".to_string(), + )); + } let input_map: ParsedInputMap<'_> = input.try_into()?; Ok(vec![parse_geometry_within(field, input_map)?]) } filters::INTERSECTS => { + if self.reverse() { + return Err(QueryGraphBuilderError::InputError( + "Negation (NOT) is not supported for geometry 'intersects' filters".to_string(), + )); + } let input_map: ParsedInputMap<'_> = input.try_into()?; Ok(vec![parse_geometry_intersects(field, input_map)?]) } diff --git a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs index 90ed1acae443..86c1f344ed6c 100644 --- a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs +++ b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs @@ -659,6 +659,39 @@ impl FilterVisitorExt for FilterVisitor { GeometryFilterCondition::Intersects { geometry, .. } => { let geom_type = geometry.get("type").and_then(|v| v.as_str()).unwrap_or(""); let wkt = match geom_type { + "Point" => { + if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { + if coords.len() >= 2 { + format!("POINT({} {})", + coords[0].as_f64().unwrap_or(0.0), + coords[1].as_f64().unwrap_or(0.0)) + } else { + panic!("Invalid Point coordinates: expected at least 2 values, got {}", coords.len()) + } + } else { + panic!("Invalid Point geometry: missing or invalid 'coordinates' array") + } + } + "LineString" => { + if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { + let point_strs: Vec = coords.iter() + .filter_map(|p| { + p.as_array().and_then(|arr| { + if arr.len() >= 2 { + Some(format!("{} {}", + arr[0].as_f64().unwrap_or(0.0), + arr[1].as_f64().unwrap_or(0.0))) + } else { + None + } + }) + }) + .collect(); + format!("LINESTRING({})", point_strs.join(", ")) + } else { + panic!("Invalid LineString geometry: missing or invalid 'coordinates' array") + } + } "Polygon" => { if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { let ring_strs: Vec = coords.iter() @@ -686,7 +719,8 @@ impl FilterVisitorExt for FilterVisitor { "POLYGON EMPTY".to_string() } } - _ => format!("POINT(0 0)"), + "" => panic!("Missing 'type' field in GeoJSON geometry"), + unsupported => panic!("Unsupported GeoJSON geometry type '{}' for intersects filter. Supported types: Point, LineString, Polygon", unsupported), }; let escaped_wkt = wkt.replace('\'', "''"); format!( diff --git a/query-compiler/query-structure/src/filter/geometry.rs b/query-compiler/query-structure/src/filter/geometry.rs index dffde355cdf8..0ce7007bf21a 100644 --- a/query-compiler/query-structure/src/filter/geometry.rs +++ b/query-compiler/query-structure/src/filter/geometry.rs @@ -1,12 +1,18 @@ use crate::*; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct GeometryFilter { pub field: ScalarFieldRef, pub condition: GeometryFilterCondition, } -#[derive(Debug, Clone, PartialEq)] +impl PartialEq for GeometryFilter { + fn eq(&self, other: &Self) -> bool { + self.field == other.field && self.condition == other.condition + } +} + +#[derive(Debug, Clone)] pub enum GeometryFilterCondition { Near { point: (f64, f64), @@ -23,6 +29,45 @@ pub enum GeometryFilterCondition { }, } +impl PartialEq for GeometryFilterCondition { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + ( + GeometryFilterCondition::Near { + point: p1, + max_distance: d1, + srid: s1, + }, + GeometryFilterCondition::Near { + point: p2, + max_distance: d2, + srid: s2, + }, + ) => { + p1.0.to_bits() == p2.0.to_bits() + && p1.1.to_bits() == p2.1.to_bits() + && d1.to_bits() == d2.to_bits() + && s1 == s2 + } + ( + GeometryFilterCondition::Within { polygon: poly1, srid: s1 }, + GeometryFilterCondition::Within { polygon: poly2, srid: s2 }, + ) => { + s1 == s2 + && poly1.len() == poly2.len() + && poly1.iter().zip(poly2.iter()).all(|((x1, y1), (x2, y2))| { + x1.to_bits() == x2.to_bits() && y1.to_bits() == y2.to_bits() + }) + } + ( + GeometryFilterCondition::Intersects { geometry: g1, srid: s1 }, + GeometryFilterCondition::Intersects { geometry: g2, srid: s2 }, + ) => s1 == s2 && g1 == g2, + _ => false, + } + } +} + impl std::hash::Hash for GeometryFilter { fn hash(&self, state: &mut H) { self.field.hash(state); diff --git a/query-compiler/query-structure/src/order_by.rs b/query-compiler/query-structure/src/order_by.rs index 5994599f377c..b7dfaea09f9d 100644 --- a/query-compiler/query-structure/src/order_by.rs +++ b/query-compiler/query-structure/src/order_by.rs @@ -231,7 +231,7 @@ pub struct OrderByRelevance { pub path: Vec, } -#[derive(Clone, PartialEq)] +#[derive(Clone)] pub struct OrderByGeometry { pub field: ScalarFieldRef, pub path: Vec, @@ -240,6 +240,17 @@ pub struct OrderByGeometry { pub srid: Option, } +impl PartialEq for OrderByGeometry { + fn eq(&self, other: &Self) -> bool { + self.field == other.field + && self.path == other.path + && self.point.0.to_bits() == other.point.0.to_bits() + && self.point.1.to_bits() == other.point.1.to_bits() + && self.sort_order == other.sort_order + && self.srid == other.srid + } +} + impl std::fmt::Debug for OrderByGeometry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OrderByGeometry") diff --git a/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs b/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs index a132e3e78390..5a2c755d6600 100644 --- a/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs +++ b/query-compiler/schema/src/build/input_types/fields/data_input_mapper/update.rs @@ -46,7 +46,7 @@ impl DataInputFieldMapper for UpdateDataInputFieldMapper { TypeIdentifier::UUID => InputType::object(update_operations_object_type(ctx, "Uuid", sf.clone(), false)), TypeIdentifier::Bytes => InputType::object(update_operations_object_type(ctx, "Bytes", sf.clone(), false)), TypeIdentifier::Geometry(_) => { - InputType::object(update_operations_object_type(ctx, "Bytes", sf.clone(), false)) + InputType::object(update_operations_object_type(ctx, "Geometry", sf.clone(), false)) } TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), From c90f1ee220688a7138843e270006d6ad7551de6e Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Fri, 22 May 2026 23:49:59 +0700 Subject: [PATCH 5/6] feat(postgis): add Geometry/Geography scalars with spatial filters while preserving native-attribute parity with VarChar/Decimal --- psl/parser-database/src/attributes.rs | 2 +- psl/parser-database/src/attributes/default.rs | 20 +- psl/parser-database/src/types.rs | 98 ++-- psl/psl-core/src/builtin_connectors/mod.rs | 4 +- .../postgres_datamodel_connector.rs | 112 ++++- .../native_types.rs | 142 +++++- psl/psl-core/src/datamodel_connector.rs | 12 +- .../validation_pipeline/validations/fields.rs | 67 ++- .../postgis_geometry_keyword_valid.prisma | 17 + ...ostgis_native_type_keyword_mismatch.prisma | 27 ++ psl/schema-ast/src/ast.rs | 2 +- psl/schema-ast/src/ast/field.rs | 49 +- psl/schema-ast/src/parser/datamodel.pest | 23 +- psl/schema-ast/src/parser/parse_types.rs | 43 -- quaint/.github/workflows/test.yml | 42 +- quaint/src/ast/function.rs | 10 + quaint/src/ast/function/postgis.rs | 104 +++++ quaint/src/visitor.rs | 16 + .../tests/geometry-filters-graph-builds.rs | 307 +++++++++++- .../tests/geometry_find_many_graph_builds.rs | 5 +- .../core/src/query_document/parser.rs | 1 - .../extractors/filters/scalar.rs | 99 +++- .../extractors/query_arguments.rs | 86 +++- query-compiler/dmmf/Cargo.toml | 2 +- .../src/ast_builders/datamodel_ast_builder.rs | 36 +- .../schema_ast_builder/type_renderer.rs | 2 +- query-compiler/dmmf/src/tests/tests.rs | 54 ++- .../sql-query-builder/src/cursor_condition.rs | 4 +- .../sql-query-builder/src/filter/visitor.rs | 179 +++---- .../sql-query-builder/src/ordering.rs | 22 +- .../query-compiler/src/data_mapper.rs | 30 +- .../query-compiler/tests/data/schema.prisma | 2 +- ...geometry-combined-scalar-spatial.json.snap | 9 +- ...eries@geometry-count-with-filter.json.snap | 7 +- ...ries@geometry-delete-with-filter.json.snap | 8 +- ...ries@geometry-filter-and-orderby.json.snap | 11 +- ...eometry-filter-and-scalar-filter.json.snap | 8 +- ...ries@geometry-filter-custom-srid.json.snap | 7 +- ...es__queries@geometry-filter-near.json.snap | 6 +- ...queries@geometry-filter-not-near.json.snap | 7 +- ...ries@geometry-filter-or-multiple.json.snap | 10 +- ...__queries@geometry-filter-within.json.snap | 6 +- ...ueries@geometry-multiple-orderby.json.snap | 8 +- ...es@geometry-orderby-distance-asc.json.snap | 7 +- ...s@geometry-orderby-distance-desc.json.snap | 7 +- ...ries@geometry-orderby-with-limit.json.snap | 7 +- .../query-structure/src/field/mod.rs | 25 +- .../query-structure/src/field/scalar.rs | 33 +- .../query-structure/src/filter/geojson.rs | 437 ++++++++++++++++++ .../query-structure/src/filter/geometry.rs | 49 +- .../query-structure/src/filter/mod.rs | 2 + .../query-structure/src/prisma_value_ext.rs | 2 +- .../input_types/fields/field_filter_types.rs | 21 +- .../schema/src/build/input_types/mod.rs | 2 +- .../input_types/objects/order_by_objects.rs | 14 +- .../schema/src/build/output_types/field.rs | 2 +- query-compiler/schema/src/output_types.rs | 4 +- query-compiler/schema/src/query_schema.rs | 8 +- .../mongodb-query-connector/src/filter.rs | 5 + .../src/flavour/postgres/renderer.rs | 7 + .../src/flavour/postgres/schema_differ.rs | 19 +- .../introspection_pair/scalar_field.rs | 30 +- .../src/sql_doc_parser.rs | 25 +- .../src/sql_schema_calculator.rs | 35 +- .../tests/postgres/postgis_geometry.rs | 55 ++- .../migrations/postgres/postgis_geometry.rs | 86 +++- .../sql-schema-describer/src/postgres.rs | 44 +- 67 files changed, 2010 insertions(+), 622 deletions(-) create mode 100644 psl/psl/tests/validation/postgres/postgis_geometry_keyword_valid.prisma create mode 100644 psl/psl/tests/validation/postgres/postgis_native_type_keyword_mismatch.prisma create mode 100644 quaint/src/ast/function/postgis.rs create mode 100644 query-compiler/query-structure/src/filter/geojson.rs diff --git a/psl/parser-database/src/attributes.rs b/psl/parser-database/src/attributes.rs index 271f27d54576..b823175517e1 100644 --- a/psl/parser-database/src/attributes.rs +++ b/psl/parser-database/src/attributes.rs @@ -57,7 +57,7 @@ fn resolve_composite_type_attributes<'db>( ctx.visit_attributes((ctid.0, (ctid.1, field_id))); - if let ScalarFieldType::BuiltInScalar(_scalar_type) = r#type { + if let ScalarFieldType::BuiltInScalar(_) = r#type { // native type attributes if let Some((datasource_name, type_name, args)) = ctx.visit_datasource_scoped() { native_types::visit_composite_type_field_native_type_attribute( diff --git a/psl/parser-database/src/attributes/default.rs b/psl/parser-database/src/attributes/default.rs index 2f9795556e87..1e39379d9e9a 100644 --- a/psl/parser-database/src/attributes/default.rs +++ b/psl/parser-database/src/attributes/default.rs @@ -73,9 +73,6 @@ pub(super) fn visit_model_field_default( "Only @default(dbgenerated(\"...\")) can be used for Unsupported types.", ); } - ScalarFieldType::Geometry(_) => { - ctx.push_attribute_validation_error("Only @default(dbgenerated(\"...\")) can be used for Geometry types."); - } } } @@ -142,9 +139,6 @@ pub(super) fn visit_composite_field_default( ScalarFieldType::Unsupported(_) => { ctx.push_attribute_validation_error("Composite field of type `Unsupported` cannot have default values.") } - ScalarFieldType::Geometry(_) => { - ctx.push_attribute_validation_error("Composite field of type `Geometry` cannot have default values.") - } } } @@ -197,6 +191,13 @@ fn validate_model_builtin_scalar_type_default( field_id: (crate::ModelId, ast::FieldId), ctx: &mut Context<'_>, ) { + // PostGIS spatial scalars only accept `@default(dbgenerated(...))` (already handled by the + // caller above). Reject everything else with a stable error matching the previous behaviour. + if matches!(scalar_type, ScalarType::Geometry | ScalarType::Geography) { + ctx.push_attribute_validation_error("Only @default(dbgenerated(\"...\")) can be used for Geometry types."); + return; + } + let arity = ctx.asts[field_id.0][field_id.1].arity; match (scalar_type, value) { // Functions @@ -257,6 +258,13 @@ fn validate_composite_builtin_scalar_type_default( field_arity: ast::FieldArity, ctx: &mut Context<'_>, ) { + // PostGIS spatial scalars cannot have defaults on composite fields (mirrors the previous + // top-level rejection arm that lived on `ScalarFieldType::Geometry`). + if matches!(scalar_type, ScalarType::Geometry | ScalarType::Geography) { + ctx.push_attribute_validation_error("Composite field of type `Geometry` cannot have default values."); + return; + } + match (scalar_type, value) { // Functions (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_ULID => { diff --git a/psl/parser-database/src/types.rs b/psl/parser-database/src/types.rs index 17e48ba0b55e..8947bd3527ed 100644 --- a/psl/parser-database/src/types.rs +++ b/psl/parser-database/src/types.rs @@ -11,7 +11,7 @@ use schema_ast::ast::{self, EnumValueId, WithName}; use serde::{Deserialize, Serialize}; use std::{ collections::BTreeMap, - fmt::{self, Write as _}, + fmt, }; pub(super) fn resolve_types(ctx: &mut Context<'_>) { @@ -188,7 +188,9 @@ impl UnsupportedType { } } -/// OGC / PostGIS geometry subtype for [`ScalarFieldType::Geometry`]. +/// OGC / PostGIS geometry subtype carried by `PostgresType::Postgis(...)` and surfaced via the +/// `@db.Geometry(...)` / `@db.Geography(...)` native attributes. The PSL keyword side uses the +/// unit [`ScalarType::Geometry`] / [`ScalarType::Geography`] variants. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum GeometrySubtype { /// `POINT` subtype. @@ -225,21 +227,6 @@ impl GeometrySubtype { } } -impl From for GeometrySubtype { - fn from(s: ast::GeometrySubtype) -> Self { - match s { - ast::GeometrySubtype::Point => Self::Point, - ast::GeometrySubtype::LineString => Self::LineString, - ast::GeometrySubtype::Polygon => Self::Polygon, - ast::GeometrySubtype::MultiPoint => Self::MultiPoint, - ast::GeometrySubtype::MultiLineString => Self::MultiLineString, - ast::GeometrySubtype::MultiPolygon => Self::MultiPolygon, - ast::GeometrySubtype::GeometryCollection => Self::GeometryCollection, - ast::GeometrySubtype::Geometry => Self::Geometry, - } - } -} - /// PostGIS base type for a [`GeometrySpec`] (`geometry` vs `geography` in PostgreSQL). #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] pub enum PostgisSpatialKind { @@ -263,16 +250,31 @@ pub struct GeometrySpec { } impl GeometrySpec { + /// PSL scalar type name as it appears in the schema language: either `Geometry` or + /// `Geography`. The casing matches the keyword the user writes in the schema. + pub fn psl_type_name(&self) -> &'static str { + match self.spatial { + PostgisSpatialKind::Geometry => "Geometry", + PostgisSpatialKind::Geography => "Geography", + } + } + /// SQL column type for PostgreSQL / PostGIS (e.g. `geometry(Point,4326)` or `geography(Point,4326)`). pub fn postgres_sql_type(&self) -> String { let base = match self.spatial { PostgisSpatialKind::Geometry => "geometry", PostgisSpatialKind::Geography => "geography", }; + // PostGIS rejects `geometry(Geometry)` as a column type — the unconstrained form is + // simply `geometry` (or `geography`). Only emit the parameter list when a concrete + // subtype or SRID is specified. + let bare_subtype = self.subtype == GeometrySubtype::Geometry; let subtype = self.subtype.as_str(); - match self.srid { - Some(srid) => format!("{base}({subtype},{srid})"), - None => format!("{base}({subtype})"), + match (self.srid, bare_subtype) { + (None, true) => base.to_owned(), + (None, false) => format!("{base}({subtype})"), + (Some(srid), true) => format!("{base}(Geometry,{srid})"), + (Some(srid), false) => format!("{base}({subtype},{srid})"), } } } @@ -288,8 +290,6 @@ pub enum ScalarFieldType { Extension(ExtensionTypeId), /// A Prisma scalar type BuiltInScalar(ScalarType), - /// PostGIS-style `Geometry(Point, 4326)` scalar - Geometry(GeometrySpec), /// An `Unsupported("...")` type Unsupported(UnsupportedType), } @@ -372,9 +372,21 @@ impl ScalarFieldType { matches!(self, Self::BuiltInScalar(ScalarType::Decimal)) } - /// True if the field's type is `Geometry(...)`. + /// True if the field's type is `Geometry` or `Geography` (any PostGIS spatial scalar). pub fn is_geometry(self) -> bool { - matches!(self, Self::Geometry(_)) + matches!( + self, + Self::BuiltInScalar(ScalarType::Geometry) | Self::BuiltInScalar(ScalarType::Geography) + ) + } + + /// PostGIS spatial kind discriminator (`Geometry` vs `Geography`) when the field type is a + /// PostGIS spatial scalar; `None` for any other field type. + pub fn postgis_spatial_kind(self) -> Option { + match self { + Self::BuiltInScalar(scalar) => scalar.postgis_spatial_kind(), + _ => None, + } } /// Display the field type as it would appear in the Prisma schema. @@ -407,13 +419,6 @@ impl fmt::Display for DisplayScalarFieldType<'_> { .expect("extension type id to have a name"); write!(f, "{}", self.db.interner.get(*name).unwrap()) } - ScalarFieldType::Geometry(spec) => { - write!(f, "Geometry({}", spec.subtype.as_str())?; - if let Some(srid) = spec.srid { - write!(f, ", {srid}")?; - } - f.write_char(')') - } ScalarFieldType::Unsupported(ut) => { write!(f, "Unsupported(\"{}\")", self.db.interner.get(ut.name).unwrap()) } @@ -948,13 +953,6 @@ fn visit_enum<'db>(enm: &'db ast::Enum, ctx: &mut Context<'db>) { /// does not match any we know of. fn field_type<'db>(field: &'db ast::Field, ctx: &mut Context<'db>) -> Result { match &field.field_type { - ast::FieldType::Geometry { subtype, srid, .. } => { - Ok(FieldType::Scalar(ScalarFieldType::Geometry(GeometrySpec { - subtype: (*subtype).into(), - srid: *srid, - spatial: PostgisSpatialKind::Geometry, - }))) - } ast::FieldType::Unsupported(name, _) => { let unsupported = UnsupportedType::new(ctx.interner.intern(name)); Ok(FieldType::Scalar(ScalarFieldType::Unsupported(unsupported))) @@ -966,6 +964,7 @@ fn field_type<'db>(field: &'db ast::Field, ctx: &mut Context<'db>) -> Result "Json", ScalarType::Bytes => "Bytes", ScalarType::Decimal => "Decimal", + ScalarType::Geometry => "Geometry", + ScalarType::Geography => "Geography", + } + } + + /// PostGIS spatial kind (`Geometry` vs `Geography`) for the two PostGIS-flavored scalar + /// variants; `None` for any non-spatial scalar. + pub fn postgis_spatial_kind(&self) -> Option { + match self { + ScalarType::Geometry => Some(PostgisSpatialKind::Geometry), + ScalarType::Geography => Some(PostgisSpatialKind::Geography), + _ => None, } } @@ -1712,6 +1730,8 @@ impl ScalarType { "json" => Some(ScalarType::Json), "bytes" => Some(ScalarType::Bytes), "decimal" => Some(ScalarType::Decimal), + "geometry" => Some(ScalarType::Geometry), + "geography" => Some(ScalarType::Geography), _ => None, }, _ => match s { @@ -1724,6 +1744,8 @@ impl ScalarType { "Json" => Some(ScalarType::Json), "Bytes" => Some(ScalarType::Bytes), "Decimal" => Some(ScalarType::Decimal), + "Geometry" => Some(ScalarType::Geometry), + "Geography" => Some(ScalarType::Geography), _ => None, }, } diff --git a/psl/psl-core/src/builtin_connectors/mod.rs b/psl/psl-core/src/builtin_connectors/mod.rs index d8e943a68c4c..5b7f72ff44b8 100644 --- a/psl/psl-core/src/builtin_connectors/mod.rs +++ b/psl/psl-core/src/builtin_connectors/mod.rs @@ -11,7 +11,9 @@ pub use mssql_datamodel_connector::{MsSqlType, MsSqlTypeParameter}; #[cfg(feature = "mysql")] pub use mysql_datamodel_connector::MySqlType; #[cfg(feature = "postgresql")] -pub use postgres_datamodel_connector::{KnownPostgresType, PostgresDatasourceProperties, PostgresType}; +pub use postgres_datamodel_connector::{ + GeometryNativeArgs, KnownPostgresType, PostgisNativeType, PostgresDatasourceProperties, PostgresType, +}; mod capabilities_support; #[cfg(feature = "mongodb")] diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index 93c52ecd9221..ea2ca8b7afaa 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -2,14 +2,14 @@ mod datasource; mod native_types; mod validations; -pub use native_types::{KnownPostgresType, PostgresType}; -use parser_database::{ExtensionTypes, GeometrySpec, ScalarFieldType}; +pub use native_types::{GeometryNativeArgs, KnownPostgresType, PostgisNativeType, PostgresType}; +use parser_database::{ExtensionTypes, GeometrySpec, PostgisSpatialKind, ScalarFieldType}; use crate::{ Configuration, Datasource, DatasourceConnectorData, PreviewFeature, ValidatedSchema, datamodel_connector::{ - Connector, ConnectorCapabilities, ConnectorCapability, ConstraintScope, Flavour, NativeTypeConstructor, - NativeTypeInstance, NativeTypeParseError, RelationMode, StringFilter, + AllowedType, Connector, ConnectorCapabilities, ConnectorCapability, ConstraintScope, Flavour, + NativeTypeConstructor, NativeTypeInstance, NativeTypeParseError, RelationMode, StringFilter, }, diagnostics::Diagnostics, parser_database::{IndexAlgorithm, OperatorClass, ParserDatabase, ReferentialAction, ScalarType, ast, walkers}, @@ -18,7 +18,7 @@ use KnownPostgresType::*; use chrono::*; use enumflags2::BitFlags; use lsp_types::{CompletionItem, CompletionItemKind, CompletionList, InsertTextFormat}; -use std::{borrow::Cow, collections::HashMap, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, sync::Arc, sync::LazyLock}; use super::completions; @@ -83,9 +83,33 @@ pub struct PostgresDatamodelConnector; const DATE_TIME_DEFAULT: KnownPostgresType = KnownPostgresType::Timestamp(Some(3)); const BYTES_DEFAULT: KnownPostgresType = KnownPostgresType::ByteA; -fn geometry_sql_column_type(spec: &GeometrySpec) -> String { - spec.postgres_sql_type() -} +/// Constructor list exposed to LSP/validation: the macro-generated scalar variants plus the +/// PostGIS `@db.Geometry(...)` / `@db.Geography(...)` entries assembled on first access. +static EXTENDED_CONSTRUCTORS: LazyLock> = LazyLock::new(|| { + // `AllowedType::field_type` is matched via `PartialEq` against the declared scalar field + // type. Each native attribute (`@db.Geometry` / `@db.Geography`) only matches the + // corresponding PSL scalar (`Geometry` / `Geography`), enforcing the pairing between the + // PSL keyword and the native attribute exactly like every other parametrized scalar + // (`String @db.VarChar(...)`, `Decimal @db.Decimal(p,s)`, ...). + let mut all = native_types::CONSTRUCTORS.to_vec(); + all.push(NativeTypeConstructor { + name: Cow::Borrowed("Geometry"), + number_of_args: 1, + number_of_optional_args: 1, + allowed_types: Cow::Owned(vec![AllowedType::plain(ScalarFieldType::BuiltInScalar( + ScalarType::Geometry, + ))]), + }); + all.push(NativeTypeConstructor { + name: Cow::Borrowed("Geography"), + number_of_args: 1, + number_of_optional_args: 1, + allowed_types: Cow::Owned(vec![AllowedType::plain(ScalarFieldType::BuiltInScalar( + ScalarType::Geography, + ))]), + }); + all +}); const SCALAR_TYPE_DEFAULTS: &[(ScalarType, KnownPostgresType)] = &[ (ScalarType::Int, KnownPostgresType::Integer), @@ -318,6 +342,17 @@ impl Connector for PostgresDatamodelConnector { .get_by_db_name_and_modifiers(name, Some(modifiers)) .map(|e| ScalarFieldType::Extension(e.id)); } + PostgresType::Postgis(postgis) => { + // Map the PostGIS spatial kind back to the matching PSL keyword. The structured + // subtype/SRID information lives in the native attribute itself, so introspection + // round-trips by attaching the `PostgresType::Postgis` instance to the field; the + // ScalarFieldType here only carries the keyword discriminator. + let scalar = match postgis.to_geometry_spec().spatial { + PostgisSpatialKind::Geometry => ScalarType::Geometry, + PostgisSpatialKind::Geography => ScalarType::Geography, + }; + return Some(ScalarFieldType::BuiltInScalar(scalar)); + } }; let res = match native_type { @@ -366,6 +401,23 @@ impl Connector for PostgresDatamodelConnector { schema: &ValidatedSchema, ) -> Option { let native_type = match scalar_type { + // PostGIS spatial scalars are not in `SCALAR_TYPE_DEFAULTS` (they have no + // `KnownPostgresType` row); fall back to the unconstrained `geometry` / `geography` + // column type so introspection can still surface a `@db.Geometry(...)` / `@db.Geography(...)` + // attribute when the user supplies one. + ScalarFieldType::BuiltInScalar(spatial @ (ScalarType::Geometry | ScalarType::Geography)) => { + let args = GeometryNativeArgs { + subtype: parser_database::GeometrySubtype::Geometry, + srid: None, + }; + let postgis = match spatial { + ScalarType::Geometry => PostgisNativeType::Geometry(args), + ScalarType::Geography => PostgisNativeType::Geography(args), + _ => unreachable!("matched only Geometry|Geography above"), + }; + let native_type = PostgresType::Postgis(postgis); + return Some(NativeTypeInstance::new::(native_type)); + } ScalarFieldType::BuiltInScalar(scalar_type) => PostgresType::Known( *SCALAR_TYPE_DEFAULTS .iter() @@ -379,10 +431,6 @@ impl Connector for PostgresDatamodelConnector { let native_type = PostgresType::Unknown(name.to_owned(), modifiers.to_vec()); return Some(NativeTypeInstance::new::(native_type)); } - ScalarFieldType::Geometry(spec) => { - let native_type = PostgresType::Unknown(geometry_sql_column_type(spec), Vec::new()); - return Some(NativeTypeInstance::new::(native_type)); - } ScalarFieldType::CompositeType(_) | ScalarFieldType::Enum(_) | ScalarFieldType::Unsupported(_) => { return None; } @@ -398,7 +446,21 @@ impl Connector for PostgresDatamodelConnector { span: ast::Span, errors: &mut Diagnostics, ) { - let PostgresType::Known(native_type) = native_type_instance.downcast_ref() else { + let postgres_type: &PostgresType = native_type_instance.downcast_ref(); + + // Validate PostGIS SRID separately because it is enforced through GeometryNativeArgs + // rather than the macro-generated KnownPostgresType arguments. + if let PostgresType::Postgis(postgis) = postgres_type { + let error = self.native_instance_error(native_type_instance); + if let Some(srid) = postgis.args().srid + && !(0..=999_999).contains(&srid) + { + errors.push_error(error.new_argument_m_out_of_range_error("SRID must be between 0 and 999999.", span)); + } + return; + } + + let PostgresType::Known(native_type) = postgres_type else { return; }; let error = self.native_instance_error(native_type_instance); @@ -457,7 +519,9 @@ impl Connector for PostgresDatamodelConnector { } fn available_native_type_constructors(&self) -> &'static [NativeTypeConstructor] { - native_types::CONSTRUCTORS + // The macro-generated CONSTRUCTORS only covers built-in scalar variants. Merge the + // PostGIS entries on top so prisma-fmt completions surface @db.Geometry / @db.Geography. + &EXTENDED_CONSTRUCTORS } fn supported_index_types(&self) -> BitFlags { @@ -477,6 +541,19 @@ impl Connector for PostgresDatamodelConnector { span: ast::Span, diagnostics: &mut Diagnostics, ) -> Option { + // Intercept the structured PostGIS native attributes before the macro-generated path so + // they go through GeometryNativeArgs parsing (subtype enum + bounded SRID) rather than + // the fallback "unknown" branch. + if let Some(parsed) = native_types::try_parse_postgis(name, args) { + return match parsed { + Ok(postgis) => Some(NativeTypeInstance::new(PostgresType::Postgis(postgis))), + Err(err) => { + diagnostics.push_error(err.into_datamodel_error(span)); + None + } + }; + } + let nt = match KnownPostgresType::from_parts(name, args) { Ok(res) => PostgresType::Known(res), Err(NativeTypeParseError::UnknownType { .. }) => PostgresType::Unknown(name.to_owned(), args.to_owned()), @@ -488,6 +565,13 @@ impl Connector for PostgresDatamodelConnector { Some(NativeTypeInstance::new(nt)) } + fn geometry_spec_for_native_type(&self, instance: &NativeTypeInstance) -> Option { + instance + .downcast_ref::() + .as_postgis() + .map(|p| p.to_geometry_spec()) + } + fn native_type_to_parts<'t>(&self, native_type: &'t NativeTypeInstance) -> (&'t str, Cow<'t, [String]>) { native_type.downcast_ref::().to_parts() } diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector/native_types.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector/native_types.rs index 166d9602e442..bc8b61036ed5 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector/native_types.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector/native_types.rs @@ -1,9 +1,19 @@ use std::borrow::Cow; +use parser_database::{GeometrySpec, GeometrySubtype, PostgisSpatialKind}; + +use crate::datamodel_connector::{NativeTypeArguments, NativeTypeParseError}; + #[derive(Debug, Clone, PartialEq)] pub enum PostgresType { Known(KnownPostgresType), Unknown(String, Vec), + /// PostGIS spatial types (`geometry(Subtype, SRID)` / `geography(Subtype, SRID)`) that the + /// generic `native_type_definition!` macro cannot express because their arguments mix an + /// enum subtype with an optional SRID. The matching PSL scalars are the unit + /// `ScalarType::Geometry` / `ScalarType::Geography` variants — subtype and SRID are carried + /// here on the native attribute (same convention as `String @db.VarChar(300)`). + Postgis(PostgisNativeType), } impl PostgresType { @@ -11,17 +21,147 @@ impl PostgresType { match self { Self::Known(known) => known.to_parts(), Self::Unknown(name, args) => (name.as_str(), Cow::Borrowed(args)), + Self::Postgis(postgis) => postgis.rendered_parts(), } } pub fn as_known(&self) -> Option<&KnownPostgresType> { match self { Self::Known(known) => Some(known), - Self::Unknown(_, _) => None, + Self::Unknown(_, _) | Self::Postgis(_) => None, + } + } + + pub fn as_postgis(&self) -> Option<&PostgisNativeType> { + match self { + Self::Postgis(postgis) => Some(postgis), + _ => None, + } + } +} + +/// PostGIS native attribute: either `@db.Geometry(subtype, srid?)` for the planar `geometry` +/// type or `@db.Geography(subtype, srid?)` for the geodetic `geography` type. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum PostgisNativeType { + Geometry(GeometryNativeArgs), + Geography(GeometryNativeArgs), +} + +impl PostgisNativeType { + pub fn args(&self) -> &GeometryNativeArgs { + match self { + Self::Geometry(args) | Self::Geography(args) => args, + } + } + + pub fn spatial(&self) -> PostgisSpatialKind { + match self { + Self::Geometry(_) => PostgisSpatialKind::Geometry, + Self::Geography(_) => PostgisSpatialKind::Geography, + } + } + + pub fn to_geometry_spec(&self) -> GeometrySpec { + let args = self.args(); + GeometrySpec { + subtype: args.subtype, + srid: args.srid, + spatial: self.spatial(), + } + } + + fn rendered_parts(self) -> (&'static str, Cow<'static, [String]>) { + let (name, args) = match self { + Self::Geometry(args) => ("Geometry", args), + Self::Geography(args) => ("Geography", args), + }; + (name, Cow::Owned(args.to_parts())) + } +} + +/// Arguments accepted by `@db.Geometry` / `@db.Geography`: a required OGC subtype optionally +/// followed by a non-negative SRID literal. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct GeometryNativeArgs { + pub subtype: GeometrySubtype, + pub srid: Option, +} + +impl GeometryNativeArgs { + pub fn to_parts(&self) -> Vec { + let mut out = vec![self.subtype.as_str().to_owned()]; + if let Some(srid) = self.srid { + out.push(srid.to_string()); } + out } } +impl NativeTypeArguments for GeometryNativeArgs { + const DESCRIPTION: &'static str = + "an OGC geometry subtype (e.g. Point, LineString, Polygon) optionally followed by an SRID between 0 and 999999"; + const REQUIRED_ARGUMENTS_COUNT: usize = 1; + const OPTIONAL_ARGUMENTS_COUNT: usize = 1; + + fn from_parts(parts: &[String]) -> Option { + match parts { + [subtype] => parse_subtype(subtype).map(|subtype| Self { subtype, srid: None }), + [subtype, srid] => { + let subtype = parse_subtype(subtype)?; + let srid = srid.parse::().ok().filter(|v| (0..=999_999).contains(v))?; + Some(Self { + subtype, + srid: Some(srid), + }) + } + _ => None, + } + } + + fn to_parts(&self) -> Vec { + GeometryNativeArgs::to_parts(self) + } +} + +fn parse_subtype(name: &str) -> Option { + match name { + "Point" => Some(GeometrySubtype::Point), + "LineString" => Some(GeometrySubtype::LineString), + "Polygon" => Some(GeometrySubtype::Polygon), + "MultiPoint" => Some(GeometrySubtype::MultiPoint), + "MultiLineString" => Some(GeometrySubtype::MultiLineString), + "MultiPolygon" => Some(GeometrySubtype::MultiPolygon), + "GeometryCollection" => Some(GeometrySubtype::GeometryCollection), + "Geometry" => Some(GeometrySubtype::Geometry), + _ => None, + } +} + +/// Parse a `@db.Geometry(...)` / `@db.Geography(...)` invocation. Returns `Some` only when the +/// name and argument shape match; the caller falls back to the generic +/// `KnownPostgresType::from_parts` path for every other native type. +pub(crate) fn try_parse_postgis<'a>( + name: &'a str, + arguments: &[String], +) -> Option>> { + let ctor = match name { + "Geometry" => PostgisNativeType::Geometry, + "Geography" => PostgisNativeType::Geography, + _ => return None, + }; + + let Some(args) = GeometryNativeArgs::from_parts(arguments) else { + let rendered_args = format!("({})", arguments.join(", ")); + return Some(Err(NativeTypeParseError::InvalidArgs { + expected: GeometryNativeArgs::DESCRIPTION, + found: rendered_args, + })); + }; + + Some(Ok(ctor(args))) +} + crate::native_type_definition! { KnownPostgresType; SmallInt -> Int, diff --git a/psl/psl-core/src/datamodel_connector.rs b/psl/psl-core/src/datamodel_connector.rs index 325b1a96a20e..c6dd93367b1e 100644 --- a/psl/psl-core/src/datamodel_connector.rs +++ b/psl/psl-core/src/datamodel_connector.rs @@ -30,7 +30,7 @@ use diagnostics::{DatamodelError, Diagnostics, NativeTypeErrorFactory, Span}; use enumflags2::BitFlags; use lsp_types::CompletionList; use parser_database::{ - ExtensionTypes, IndexAlgorithm, ParserDatabase, ReferentialAction, ScalarFieldType, ScalarType, + ExtensionTypes, GeometrySpec, IndexAlgorithm, ParserDatabase, ReferentialAction, ScalarFieldType, ScalarType, ast::{self, SchemaPosition}, walkers, }; @@ -213,6 +213,16 @@ pub trait Connector: Send + Sync { diagnostics: &mut Diagnostics, ) -> Option; + /// If this native type expresses a PostGIS spatial column (`@db.Geometry(...)` / + /// `@db.Geography(...)`), return the equivalent `GeometrySpec`. Other connectors keep + /// the default `None` implementation. + /// + /// Used by query-structure (and downstream SQL generation) to derive the spatial + /// kind, subtype and SRID from the schema instead of guessing from the SRID value. + fn geometry_spec_for_native_type(&self, _instance: &NativeTypeInstance) -> Option { + None + } + fn native_type_supports_compacting(&self, _: Option) -> bool { true } diff --git a/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs b/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs index a82f3fbc2273..d5aafbb8d4f5 100644 --- a/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs +++ b/psl/psl-core/src/validate/validation_pipeline/validations/fields.rs @@ -10,7 +10,7 @@ use crate::datamodel_connector::{ConnectorCapability, NativeTypeConstructor, wal use crate::{diagnostics::DatamodelError, validate::validation_pipeline::context::Context}; use itertools::Itertools; use parser_database::{ - GeometrySpec, ScalarFieldType, ScalarType, + ScalarFieldType, ScalarType, ast::{self, WithSpan}, walkers::{ CompositeTypeFieldWalker, FieldWalker, PrimaryKeyWalker, ScalarFieldAttributeWalker, ScalarFieldWalker, @@ -331,50 +331,41 @@ pub(super) fn validate_scalar_field_connector_specific(field: ScalarFieldWalker< } } -fn validate_geometry_spec_constraints( - spec: GeometrySpec, +fn require_postgis_capability( ctx: &mut Context<'_>, container: &str, container_name: &str, field_name: &str, field_span: ast::Span, - type_span: ast::Span, ) { - if !ctx.has_capability(ConnectorCapability::PostgisGeometry) { - let msg = format!( - "Field `{field_name}` in {container} `{container_name}` uses type Geometry, which is only supported on PostgreSQL with PostGIS.", - ); - if container == "composite type" { - ctx.push_error(DatamodelError::new_composite_type_validation_error( - &msg, - container_name, - field_span, - )); - } else { - ctx.push_error(DatamodelError::new_field_validation_error( - &msg, - container, - container_name, - field_name, - field_span, - )); - } + if ctx.has_capability(ConnectorCapability::PostgisGeometry) { + return; } - if let Some(srid) = spec.srid - && (srid < 0 || srid > 999_999) - { - ctx.push_error(DatamodelError::new_validation_error( - &format!("Invalid SRID {srid}. Must be between 0 and 999999 when specified."), - type_span, + let msg = format!( + "Field `{field_name}` in {container} `{container_name}` uses type Geometry, which is only supported on PostgreSQL with PostGIS.", + ); + if container == "composite type" { + ctx.push_error(DatamodelError::new_composite_type_validation_error( + &msg, + container_name, + field_span, + )); + } else { + ctx.push_error(DatamodelError::new_field_validation_error( + &msg, + container, + container_name, + field_name, + field_span, )); } } pub(super) fn validate_geometry_field(field: ScalarFieldWalker<'_>, ctx: &mut Context<'_>) { - let ScalarFieldType::Geometry(spec) = field.scalar_field_type() else { + if !field.scalar_field_type().is_geometry() { return; - }; + } let container = if field.model().ast_model().is_view() { "view" @@ -382,30 +373,28 @@ pub(super) fn validate_geometry_field(field: ScalarFieldWalker<'_>, ctx: &mut Co "model" }; - validate_geometry_spec_constraints( - spec, + require_postgis_capability( ctx, container, field.model().name(), field.name(), field.ast_field().span(), - field.ast_field().field_type.span(), ); + // SRID range is enforced by `validate_native_type_arguments` on `PostgresType::Postgis(...)` + // (the only place a structured SRID exists). Nothing else to check at this layer. } pub(super) fn validate_geometry_on_composite_field(field: CompositeTypeFieldWalker<'_>, ctx: &mut Context<'_>) { - let ScalarFieldType::Geometry(spec) = field.r#type() else { + if !field.r#type().is_geometry() { return; - }; + } - validate_geometry_spec_constraints( - spec, + require_postgis_capability( ctx, "composite type", field.composite_type().name(), field.name(), field.ast_field().span(), - field.ast_field().field_type.span(), ); } diff --git a/psl/psl/tests/validation/postgres/postgis_geometry_keyword_valid.prisma b/psl/psl/tests/validation/postgres/postgis_geometry_keyword_valid.prisma new file mode 100644 index 000000000000..6ce4329bf619 --- /dev/null +++ b/psl/psl/tests/validation/postgres/postgis_geometry_keyword_valid.prisma @@ -0,0 +1,17 @@ +datasource db { + provider = "postgresql" +} + +generator client { + provider = "prisma-client" +} + +// SevInf #1/#4: `Geometry` and `Geography` are first-class PSL scalar types. +// Each pairs with the matching `@db.Geometry(...)` / `@db.Geography(...)` native attribute. +model Location { + id Int @id + position Geometry @db.Geometry(Point, 4326) + path Geometry? @db.Geometry(LineString) + region Geography @db.Geography(Polygon, 4326) + marker Geography @db.Geography(Point) +} diff --git a/psl/psl/tests/validation/postgres/postgis_native_type_keyword_mismatch.prisma b/psl/psl/tests/validation/postgres/postgis_native_type_keyword_mismatch.prisma new file mode 100644 index 000000000000..9999cef9fd8b --- /dev/null +++ b/psl/psl/tests/validation/postgres/postgis_native_type_keyword_mismatch.prisma @@ -0,0 +1,27 @@ +datasource db { + provider = "postgresql" +} + +generator client { + provider = "prisma-client" +} + +// SevInf #1/#4: the native attribute must match the PSL keyword. `@db.Geography` cannot decorate +// a `Geometry` field, and `@db.Geometry` cannot decorate a `Geography` field. +model Location { + id Int @id + bad_planar Geometry @db.Geography(Point, 4326) + bad_geodetic Geography @db.Geometry(Point, 4326) +} +// error: Native type Geography is not compatible with declared field type Geometry, expected field type Geography. +// --> schema.prisma:13 +// | +// 12 | id Int @id +// 13 | bad_planar Geometry @db.Geography(Point, 4326) +// | +// error: Native type Geometry is not compatible with declared field type Geography, expected field type Geometry. +// --> schema.prisma:14 +// | +// 13 | bad_planar Geometry @db.Geography(Point, 4326) +// 14 | bad_geodetic Geography @db.Geometry(Point, 4326) +// | diff --git a/psl/schema-ast/src/ast.rs b/psl/schema-ast/src/ast.rs index f762a34d9158..8158e30f73a2 100644 --- a/psl/schema-ast/src/ast.rs +++ b/psl/schema-ast/src/ast.rs @@ -25,7 +25,7 @@ pub use config::ConfigBlockProperty; pub use diagnostics::Span; pub use r#enum::{Enum, EnumValue, EnumValueId}; pub use expression::{Expression, ObjectMember}; -pub use field::{Field, FieldArity, FieldType, GeometrySubtype}; +pub use field::{Field, FieldArity, FieldType}; pub use find_at_position::*; pub use generator_config::GeneratorConfig; pub use identifier::Identifier; diff --git a/psl/schema-ast/src/ast/field.rs b/psl/schema-ast/src/ast/field.rs index c3ba2f131fd4..ba95def2e648 100644 --- a/psl/schema-ast/src/ast/field.rs +++ b/psl/schema-ast/src/ast/field.rs @@ -1,38 +1,9 @@ -use std::fmt::{Display, Write}; +use std::fmt::Display; use super::{ Attribute, Comment, Identifier, Span, WithAttributes, WithDocumentation, WithIdentifier, WithName, WithSpan, }; -/// OGC / PostGIS geometry subtype written in `Geometry(...)` field types. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum GeometrySubtype { - Point, - LineString, - Polygon, - MultiPoint, - MultiLineString, - MultiPolygon, - GeometryCollection, - Geometry, -} - -impl GeometrySubtype { - /// PSL spelling of the subtype (e.g. `Point`). - pub fn as_str(self) -> &'static str { - match self { - GeometrySubtype::Point => "Point", - GeometrySubtype::LineString => "LineString", - GeometrySubtype::Polygon => "Polygon", - GeometrySubtype::MultiPoint => "MultiPoint", - GeometrySubtype::MultiLineString => "MultiLineString", - GeometrySubtype::MultiPolygon => "MultiPolygon", - GeometrySubtype::GeometryCollection => "GeometryCollection", - GeometrySubtype::Geometry => "Geometry", - } - } -} - /// A field definition in a model or a composite type. #[derive(Debug, Clone)] pub struct Field { @@ -178,12 +149,6 @@ impl FieldArity { #[derive(Debug, Clone, PartialEq)] pub enum FieldType { Supported(Identifier), - /// `Geometry(Point, 4326)` or `Geometry(LineString)` (SRID optional). - Geometry { - subtype: GeometrySubtype, - srid: Option, - span: Span, - }, /// Unsupported("...") Unsupported(String, Span), } @@ -192,7 +157,6 @@ impl FieldType { pub fn span(&self) -> Span { match self { FieldType::Supported(ident) => ident.span, - FieldType::Geometry { span, .. } => *span, FieldType::Unsupported(_, span) => *span, } } @@ -200,7 +164,6 @@ impl FieldType { pub fn name(&self) -> &str { match self { FieldType::Supported(supported) => &supported.name, - FieldType::Geometry { .. } => "Geometry", FieldType::Unsupported(name, _) => name, } } @@ -208,7 +171,7 @@ impl FieldType { pub fn as_unsupported(&self) -> Option<(&str, &Span)> { match self { FieldType::Unsupported(name, span) => Some((name, span)), - FieldType::Supported(_) | FieldType::Geometry { .. } => None, + FieldType::Supported(_) => None, } } } @@ -217,14 +180,6 @@ impl Display for FieldType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { FieldType::Supported(ident) => f.write_str(&ident.name), - FieldType::Geometry { subtype, srid, .. } => { - f.write_str("Geometry(")?; - f.write_str(subtype.as_str())?; - if let Some(srid) = srid { - write!(f, ", {srid}")?; - } - f.write_char(')') - } FieldType::Unsupported(name, _) => write!(f, "Unsupported({})", crate::string_literal(name)), } } diff --git a/psl/schema-ast/src/parser/datamodel.pest b/psl/schema-ast/src/parser/datamodel.pest index 266105cbcf6d..8338b427698d 100644 --- a/psl/schema-ast/src/parser/datamodel.pest +++ b/psl/schema-ast/src/parser/datamodel.pest @@ -56,28 +56,7 @@ field_type = { unsupported_optional_list_type | list_type | optional_type | lega unsupported_type = { "Unsupported(" ~ string_literal ~ ")" } -geometry_subtype = { - "MultiPoint" - | "MultiLineString" - | "MultiPolygon" - | "GeometryCollection" - | "Point" - | "LineString" - | "Polygon" - | "Geometry" -} - -geometry_srid = @{ ASCII_DIGIT+ } - -geometry_type = { - "Geometry" - ~ "(" - ~ geometry_subtype - ~ ( "," ~ geometry_srid )? - ~ ")" -} - -base_type = { unsupported_type | geometry_type | identifier } // Called base type to not conflict with type rust keyword +base_type = { unsupported_type | identifier } // Called base type to not conflict with type rust keyword unsupported_optional_list_type = { base_type ~ "[]" ~ "?" } list_type = { base_type ~ "[]" } optional_type = { base_type ~ "?" } diff --git a/psl/schema-ast/src/parser/parse_types.rs b/psl/schema-ast/src/parser/parse_types.rs index 5d6de2393dd3..1391c8b28919 100644 --- a/psl/schema-ast/src/parser/parse_types.rs +++ b/psl/schema-ast/src/parser/parse_types.rs @@ -50,49 +50,6 @@ fn parse_base_type( Expression::StringValue(lit, span) => Ok(FieldType::Unsupported(lit, span)), _ => unreachable!("Encountered impossible type during parsing"), }, - Rule::geometry_type => parse_geometry_type(current, file_id), _ => unreachable!("Encountered impossible type during parsing: {:?}", current.tokens()), } } - -fn parse_geometry_type(pair: Pair<'_>, file_id: FileId) -> Result { - let span = Span::from((file_id, pair.as_span())); - let mut inner = pair.into_inner(); - let subtype_pair = inner.next().expect("geometry: subtype"); - debug_assert_eq!(subtype_pair.as_rule(), Rule::geometry_subtype); - let subtype = match subtype_pair.as_str() { - "Point" => crate::ast::GeometrySubtype::Point, - "LineString" => crate::ast::GeometrySubtype::LineString, - "Polygon" => crate::ast::GeometrySubtype::Polygon, - "MultiPoint" => crate::ast::GeometrySubtype::MultiPoint, - "MultiLineString" => crate::ast::GeometrySubtype::MultiLineString, - "MultiPolygon" => crate::ast::GeometrySubtype::MultiPolygon, - "GeometryCollection" => crate::ast::GeometrySubtype::GeometryCollection, - "Geometry" => crate::ast::GeometrySubtype::Geometry, - _ => unreachable!("geometry_subtype rule produced unexpected token"), - }; - - let srid = if let Some(srid_pair) = inner.next() { - debug_assert_eq!(srid_pair.as_rule(), Rule::geometry_srid); - let raw = srid_pair.as_str(); - match raw.parse::() { - Ok(v) if v >= 0 && v <= 999_999 => Some(v), - Ok(v) => { - return Err(DatamodelError::new_validation_error( - &format!("Invalid SRID: expected a value between 0 and 999999, got {}.", v), - (file_id, srid_pair.as_span()).into(), - )); - } - Err(_) => { - return Err(DatamodelError::new_validation_error( - "Invalid SRID: expected a valid 32-bit integer.", - (file_id, srid_pair.as_span()).into(), - )); - } - } - } else { - None - }; - - Ok(FieldType::Geometry { subtype, srid, span }) -} diff --git a/quaint/.github/workflows/test.yml b/quaint/.github/workflows/test.yml index cb8cf5a88cf2..dd9a1b8e96c4 100644 --- a/quaint/.github/workflows/test.yml +++ b/quaint/.github/workflows/test.yml @@ -8,14 +8,14 @@ jobs: clippy: runs-on: ubuntu-latest env: - RUSTFLAGS: '-Dwarnings' + RUSTFLAGS: "-Dwarnings" steps: - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: - components: clippy - override: true - toolchain: stable + components: clippy + override: true + toolchain: stable - name: Install dependencies run: sudo apt install -y openssl libkrb5-dev - uses: actions-rs/clippy-check@v1 @@ -44,24 +44,24 @@ jobs: fail-fast: false matrix: features: - - '--lib --features=all' - - '--lib --no-default-features --features=sqlite' - - '--lib --no-default-features --features=sqlite --features=pooled' - - '--lib --no-default-features --features=postgresql' - - '--lib --no-default-features --features=postgresql --features=pooled' - - '--lib --no-default-features --features=mysql' - - '--lib --no-default-features --features=mysql --features=pooled' - - '--lib --no-default-features --features=mssql' - - '--lib --no-default-features --features=mssql --features=pooled' - - '--doc --features=all' + - "--lib --features=all" + - "--lib --no-default-features --features=sqlite" + - "--lib --no-default-features --features=sqlite --features=pooled" + - "--lib --no-default-features --features=postgresql" + - "--lib --no-default-features --features=postgresql --features=pooled" + - "--lib --no-default-features --features=mysql" + - "--lib --no-default-features --features=mysql --features=pooled" + - "--lib --no-default-features --features=mssql" + - "--lib --no-default-features --features=mssql --features=pooled" + - "--doc --features=all" env: - TEST_MYSQL: 'mysql://root:prisma@localhost:3306/prisma' - TEST_MYSQL8: 'mysql://root:prisma@localhost:3307/prisma' - TEST_MYSQL_MARIADB: 'mysql://root:prisma@localhost:3308/prisma' - TEST_PSQL: 'postgres://postgres:prisma@localhost:5432/postgres' - TEST_MSSQL: 'jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true' - TEST_CRDB: 'postgresql://prisma@127.0.0.1:26259/postgres' - RUSTFLAGS: '-Dwarnings' + TEST_MYSQL: "mysql://root:prisma@localhost:3306/prisma" + TEST_MYSQL8: "mysql://root:prisma@localhost:3307/prisma" + TEST_MYSQL_MARIADB: "mysql://root:prisma@localhost:3308/prisma" + TEST_PSQL: "postgres://postgres:prisma@localhost:5432/postgres" + TEST_MSSQL: "jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true" + TEST_CRDB: "postgresql://prisma@127.0.0.1:26259/postgres" + RUSTFLAGS: "-Dwarnings" steps: - uses: actions/checkout@v4 diff --git a/quaint/src/ast/function.rs b/quaint/src/ast/function.rs index 38c250fe07de..e269e7b2e5ae 100644 --- a/quaint/src/ast/function.rs +++ b/quaint/src/ast/function.rs @@ -11,6 +11,8 @@ mod json_unquote; mod lower; mod maximum; mod minimum; +#[cfg(feature = "postgresql")] +mod postgis; mod row_number; mod row_to_json; mod search; @@ -33,6 +35,8 @@ pub use json_unquote::*; pub use lower::*; pub use maximum::*; pub use minimum::*; +#[cfg(feature = "postgresql")] +pub use postgis::*; pub use row_number::*; pub use row_to_json::*; pub use search::*; @@ -90,6 +94,8 @@ pub(crate) enum FunctionType<'a> { JsonBuildObject(JsonBuildObject<'a>), TextSearch(TextSearch<'a>), TextSearchRelevance(TextSearchRelevance<'a>), + #[cfg(feature = "postgresql")] + Postgis(postgis::PostgisFunction<'a>), UuidToBin, UuidToBinSwapped, Uuid, @@ -116,6 +122,8 @@ impl<'a> FunctionType<'a> { Self::TextSearch(f) => &f.exprs, Self::TextSearchRelevance(f) => &f.exprs, Self::Stringify(f) => slice::from_ref(&f.expression), + #[cfg(feature = "postgresql")] + Self::Postgis(f) => &f.args, Self::RowToJson(_) | Self::RowNumber(_) | Self::Average(_) @@ -157,6 +165,8 @@ impl<'a> FunctionType<'a> { | Self::UuidToBinSwapped | Self::Uuid | Self::Stringify(_) => return None, + #[cfg(feature = "postgresql")] + Self::Postgis(_) => return None, }; Some(name) } diff --git a/quaint/src/ast/function/postgis.rs b/quaint/src/ast/function/postgis.rs new file mode 100644 index 000000000000..a812a236f568 --- /dev/null +++ b/quaint/src/ast/function/postgis.rs @@ -0,0 +1,104 @@ +use super::Function; +use crate::ast::Expression; + +/// A PostGIS function call. Used to render spatial SQL expressions (`ST_*`) with each argument +/// going through the regular parameterized-expression path of the visitor, so no user input +/// is ever interpolated into raw SQL text. +#[derive(Debug, Clone, PartialEq)] +pub struct PostgisFunction<'a> { + pub(crate) name: &'static str, + pub(crate) args: Vec>, +} + +impl<'a> PostgisFunction<'a> { + pub(crate) fn build(name: &'static str, args: Vec>) -> Function<'a> { + Function { + typ_: super::FunctionType::Postgis(Self { name, args }), + alias: None, + } + } + + pub fn name(&self) -> &'static str { + self.name + } + + pub fn args(&self) -> &[Expression<'a>] { + &self.args + } +} + +/// `ST_DWithin(geom_a, geom_b, distance)` - returns `true` when the geometries are within +/// `distance` meters/units of each other. +pub fn st_dwithin<'a, A, B, C>(geom_a: A, geom_b: B, distance: C) -> Function<'a> +where + A: Into>, + B: Into>, + C: Into>, +{ + PostgisFunction::build("ST_DWithin", vec![geom_a.into(), geom_b.into(), distance.into()]) +} + +/// `ST_Within(geom_a, geom_b)` - returns `true` if `geom_a` is completely contained inside +/// `geom_b`. +pub fn st_within<'a, A, B>(geom_a: A, geom_b: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_Within", vec![geom_a.into(), geom_b.into()]) +} + +/// `ST_Intersects(geom_a, geom_b)` - returns `true` if the geometries share any point. +pub fn st_intersects<'a, A, B>(geom_a: A, geom_b: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_Intersects", vec![geom_a.into(), geom_b.into()]) +} + +/// `ST_Distance(geom_a, geom_b)` - returns the minimum distance between the geometries. +pub fn st_distance<'a, A, B>(geom_a: A, geom_b: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_Distance", vec![geom_a.into(), geom_b.into()]) +} + +/// `ST_GeomFromText(wkt, srid)` - parses a WKT string into a geometry with the given SRID. +pub fn st_geom_from_text<'a, A, B>(wkt: A, srid: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_GeomFromText", vec![wkt.into(), srid.into()]) +} + +/// `ST_MakePoint(x, y)` - constructs a 2D point. +pub fn st_make_point<'a, A, B>(x: A, y: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_MakePoint", vec![x.into(), y.into()]) +} + +/// `ST_SetSRID(geom, srid)` - assigns/overrides the SRID of a geometry without reprojecting. +pub fn st_set_srid<'a, A, B>(geom: A, srid: B) -> Function<'a> +where + A: Into>, + B: Into>, +{ + PostgisFunction::build("ST_SetSRID", vec![geom.into(), srid.into()]) +} + +/// `geography(geom)` - PostGIS conversion from `geometry` to `geography`. Equivalent to the +/// `::geography` cast but expressible inside the Function AST so the operand stays a regular +/// parameterized expression. +pub fn geography_cast<'a, A>(geom: A) -> Function<'a> +where + A: Into>, +{ + PostgisFunction::build("geography", vec![geom.into()]) +} diff --git a/quaint/src/visitor.rs b/quaint/src/visitor.rs index 56d798fc0c02..ab17400979f8 100644 --- a/quaint/src/visitor.rs +++ b/quaint/src/visitor.rs @@ -1279,6 +1279,22 @@ pub trait Visitor<'a> { FunctionType::Stringify(stringify) => { self.visit_stringify(stringify)?; } + #[cfg(feature = "postgresql")] + FunctionType::Postgis(postgis) => { + let name = postgis.name; + let args = postgis.args; + self.write(name)?; + self.surround_with("(", ")", |this| { + let last = args.len().saturating_sub(1); + for (i, arg) in args.into_iter().enumerate() { + this.visit_expression(arg)?; + if i < last { + this.write(", ")?; + } + } + Ok(()) + })?; + } }; if let Some(alias) = fun.alias { diff --git a/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs b/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs index ace8692513e4..c31bfcd4a0c1 100644 --- a/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs +++ b/query-compiler/core-tests/tests/geometry-filters-graph-builds.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use query_core::{QueryDocument, QueryGraphBuilder}; +use query_core::{QueryDocument, QueryGraphBuilder, with_sync_unevaluated_request_context}; use request_handlers::{JsonBody, JsonSingleQuery, RequestBody}; #[test] @@ -16,7 +16,7 @@ fn geometry_near_filter_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -69,7 +69,7 @@ fn geometry_within_filter_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -127,7 +127,7 @@ fn geometry_orderby_distance_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -182,7 +182,7 @@ fn geometry_combined_filter_and_orderby_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -245,7 +245,7 @@ fn geometry_not_filter_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -300,7 +300,7 @@ fn geometry_or_filter_builds_query_graph() { model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#; @@ -365,7 +365,7 @@ fn geometry_custom_srid_builds_query_graph() { model LocationMercator { id Int @id @default(autoincrement()) - position Geometry(Point, 3857)? + position Geometry? @db.Geometry(Point, 3857) } "#; @@ -409,3 +409,294 @@ fn geometry_custom_srid_builds_query_graph() { .build(query) .expect("findMany with custom SRID 3857 should compile to a query graph"); } + +/// Combining cursor-based pagination with `orderBy: { position: { distance: { ... } } }` has no +/// deterministic SQL translation (the distance reference point is not part of the row). The +/// extractor must reject the combination with a clear `InputError` instead of letting it reach +/// the SQL builder where it would panic with `unimplemented!()`. +#[test] +fn cursor_with_geometry_orderby_is_rejected() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry @db.Geometry(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "cursor": { "id": 1 }, + "orderBy": [ + { "position": { "distanceFrom": { "point": [0, 0], "direction": "asc" } } } + ] + }, + "selection": { + "id": true, + "position": true + } + } + }"#; + + with_sync_unevaluated_request_context(|| { + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + let error = QueryGraphBuilder::new(&query_schema) + .build(query) + .expect_err("cursor + geometry orderBy must be rejected"); + + let message = format!("{error}"); + assert!( + message.contains("Cursor-based pagination") && message.to_ascii_lowercase().contains("geometry"), + "unexpected error message: {message}", + ); + }); +} + +/// Malformed GeoJSON in an `intersects` filter must surface as a user-facing `InputError` instead +/// of crashing the builder (older code used `panic!()` for polygons missing the closing vertex). +#[test] +fn intersects_with_malformed_polygon_returns_input_error() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry @db.Geometry(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + // Polygon ring with only two distinct positions — auto-close cannot rescue it, and the + // extractor must surface a parse error rather than panicking. + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "intersects": { + "geometry": { + "type": "Polygon", + "coordinates": [[[0, 0], [1, 1]]] + }, + "srid": 4326 + } + } + } + }, + "selection": { "id": true } + } + }"#; + + with_sync_unevaluated_request_context(|| { + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + let error = QueryGraphBuilder::new(&query_schema) + .build(query) + .expect_err("malformed polygon must be rejected"); + + let message = format!("{error}"); + assert!( + message.to_ascii_lowercase().contains("polygon"), + "expected error to mention the polygon ring, got: {message}", + ); + }); +} + +/// `intersects` filter must reject `Multi*` / `GeometryCollection` GeoJSON shapes at extraction +/// time. The previous behaviour silently emitted a `false` predicate inside the SQL builder, which +/// made empty result sets indistinguishable from "filter ignored". This regression test pins the +/// fast-fail path so future contributors don't reintroduce the silent fallback. +#[test] +fn intersects_with_unsupported_geojson_returns_input_error() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + model Location { + id Int @id @default(autoincrement()) + position Geometry @db.Geometry(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Location", + "action": "findMany", + "query": { + "arguments": { + "where": { + "position": { + "intersects": { + "geometry": { + "type": "MultiPoint", + "coordinates": [[0.0, 0.0], [1.0, 1.0]] + }, + "srid": 4326 + } + } + } + }, + "selection": { "id": true } + } + }"#; + + with_sync_unevaluated_request_context(|| { + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + let error = QueryGraphBuilder::new(&query_schema) + .build(query) + .expect_err("MultiPoint geometry must be rejected by the extractor"); + + let message = format!("{error}"); + assert!( + message.contains("MultiPoint"), + "expected error to mention the unsupported GeoJSON type, got: {message}", + ); + }); +} + +/// SevInf #1/#4: `Geography` is a first-class PSL scalar type. The geometry filter pipeline must +/// accept fields declared as `Geography @db.Geography(...)` end-to-end and route them through the +/// same near/within/intersects machinery as `Geometry` fields. +#[test] +fn geography_field_supports_full_filter_pipeline() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + generator client { + provider = "prisma-client" + } + + model Place { + id Int @id @default(autoincrement()) + location Geography @db.Geography(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + assert!(!schema.diagnostics.has_errors(), "{:?}", schema.diagnostics); + + let schema = Arc::new(schema); + let query_schema = Arc::new(query_core::schema::build(schema, true)); + + let query_json = r#"{ + "modelName": "Place", + "action": "findMany", + "query": { + "arguments": { + "where": { + "location": { + "near": { + "point": [2.35, 48.85], + "maxDistance": 100000 + } + } + }, + "orderBy": [ + { "location": { "distanceFrom": { "point": [0, 0], "direction": "asc" } } } + ] + }, + "selection": { + "id": true, + "location": true + } + } + }"#; + + with_sync_unevaluated_request_context(|| { + let query: JsonSingleQuery = serde_json::from_str(query_json).unwrap(); + let request = RequestBody::Json(JsonBody::Single(query)); + let doc = request.into_doc(&query_schema).unwrap(); + + let QueryDocument::Single(query) = doc else { + panic!("expected single query"); + }; + + QueryGraphBuilder::new(&query_schema) + .build(query) + .expect("Geography field with near + orderBy should compile to a query graph"); + }); +} + +/// SevInf #1/#4: `Geometry @db.Geography(...)` (and the symmetric mismatch) is rejected at PSL +/// validation time so the rest of the toolchain never has to reason about an inconsistent +/// `geometry`/`geography` declaration. +#[test] +fn mismatched_geometry_geography_native_attribute_is_rejected() { + let schema_string = r#" + datasource db { + provider = "postgresql" + } + + generator client { + provider = "prisma-client" + } + + model Bad { + id Int @id + loc Geometry @db.Geography(Point, 4326) + } + "#; + + let schema = psl::validate_without_extensions(schema_string.into()); + let messages: Vec = schema + .diagnostics + .errors() + .iter() + .map(|e| e.message().to_owned()) + .collect(); + assert!( + messages + .iter() + .any(|m| m.contains("Native type Geography is not compatible with declared field type Geometry")), + "expected mismatch error, got: {messages:?}", + ); +} diff --git a/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs b/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs index 3aa8dcb3a896..4b91ee5a8d36 100644 --- a/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs +++ b/query-compiler/core-tests/tests/geometry_find_many_graph_builds.rs @@ -11,13 +11,12 @@ fn geometry_find_many_builds_query_graph() { } generator client { - provider = "prisma-client" - previewFeatures = ["relationJoins"] + provider = "prisma-client" } model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326) + position Geometry @db.Geometry(Point, 4326) } "#; diff --git a/query-compiler/core/src/query_document/parser.rs b/query-compiler/core/src/query_document/parser.rs index ff6ce1344375..afe41cc7a1c1 100644 --- a/query-compiler/core/src/query_document/parser.rs +++ b/query-compiler/core/src/query_document/parser.rs @@ -420,7 +420,6 @@ impl QueryDocumentParser { (PrismaValue::Bytes(bytes), &ScalarType::Bytes) => Ok(PrismaValue::Bytes(bytes)), (pv @ PrismaValue::Bytes(_), &ScalarType::Geometry(_)) => Ok(pv), (pv @ PrismaValue::String(_), &ScalarType::Geometry(_)) => Ok(pv), - (PrismaValue::Json(s), &ScalarType::Geometry(_)) => Ok(PrismaValue::Bytes(s.into_bytes())), (PrismaValue::BigInt(b_int), &ScalarType::BigInt) => Ok(PrismaValue::BigInt(b_int)), (PrismaValue::DateTime(s), &ScalarType::DateTime) => Ok(PrismaValue::DateTime(s)), (PrismaValue::Null, &ScalarType::Null) => Ok(PrismaValue::Null), diff --git a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs index 5b9b48d2c9a7..fbc8591e9685 100644 --- a/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs +++ b/query-compiler/core/src/query_graph_builder/extractors/filters/scalar.rs @@ -672,12 +672,18 @@ fn parse_geometry_near(field: &ScalarFieldRef, mut input_map: ParsedInputMap<'_> let lon = extract_float(&point_list[0])?; let lat = extract_float(&point_list[1])?; let max_distance = extract_float(&max_distance_value.try_into()?)?; + if max_distance < 0.0 { + return Err(QueryGraphBuilderError::InputError( + "near filter `maxDistance` must be a non-negative number".to_owned(), + )); + } + let point = GeoCoord::new(lon, lat).map_err(|e| QueryGraphBuilderError::InputError(e.to_string()))?; let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; Ok(Filter::Geometry(GeometryFilter { field: field.clone(), condition: GeometryFilterCondition::Near { - point: (lon, lat), + point, max_distance, srid, }, @@ -691,7 +697,7 @@ fn parse_geometry_within(field: &ScalarFieldRef, mut input_map: ParsedInputMap<' let srid_value = input_map.swap_remove(filters::SRID); let polygon_outer: Vec = polygon_value.try_into()?; - let mut polygon = Vec::with_capacity(polygon_outer.len()); + let mut polygon = Vec::with_capacity(polygon_outer.len() + 1); for coord in polygon_outer { if let PrismaValue::List(pair) = coord { @@ -702,7 +708,8 @@ fn parse_geometry_within(field: &ScalarFieldRef, mut input_map: ParsedInputMap<' } let lon = extract_float(&pair[0])?; let lat = extract_float(&pair[1])?; - polygon.push((lon, lat)); + let position = GeoCoord::new(lon, lat).map_err(|e| QueryGraphBuilderError::InputError(e.to_string()))?; + polygon.push(position); } else { return Err(QueryGraphBuilderError::InputError( "polygon must be an array of coordinate pairs".to_owned(), @@ -710,6 +717,21 @@ fn parse_geometry_within(field: &ScalarFieldRef, mut input_map: ParsedInputMap<' } } + // Match PostGIS leniency: auto-close a ring whose user-supplied first and last positions + // differ. Rings shorter than 3 distinct positions are rejected because they cannot define + // any area. + if polygon.len() < 3 { + return Err(QueryGraphBuilderError::InputError(format!( + "within filter polygon must contain at least 3 distinct positions, got {}", + polygon.len() + ))); + } + if polygon.first() != polygon.last() { + let first = *polygon.first().expect("polygon was just validated to be non-empty"); + polygon.push(first); + } + debug_assert!(polygon.len() >= 4); + let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; Ok(Filter::Geometry(GeometryFilter { @@ -728,7 +750,7 @@ fn parse_geometry_intersects( let srid_value = input_map.swap_remove(filters::SRID); let geometry_json: PrismaValue = geometry_value.try_into()?; - let geometry = match geometry_json { + let raw_value: serde_json::Value = match geometry_json { PrismaValue::Json(json_str) => serde_json::from_str(&json_str) .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid GeoJSON: {}", e)))?, PrismaValue::Object(obj) => serde_json::to_value(obj) @@ -736,10 +758,25 @@ fn parse_geometry_intersects( _ => { return Err(QueryGraphBuilderError::InputError( "intersects geometry must be a JSON value".to_owned(), - )) + )); } }; + let geometry = + GeoJsonGeometry::from_serde_value(&raw_value).map_err(|e| QueryGraphBuilderError::InputError(e.to_string()))?; + + // Reject GeoJSON shapes the SQL builder cannot lower to a single WKT (`Multi*` / + // `GeometryCollection`). `to_wkt()` returning `None` is the canonical signal: instead of + // letting the visitor silently emit a `false` predicate (which used to be the behaviour), + // surface a clear input error so callers can distinguish "no rows match" from "we never + // executed the requested filter". + if geometry.to_wkt().is_none() { + return Err(QueryGraphBuilderError::InputError(format!( + "intersects filter does not yet support GeoJSON `{}` geometries; use Point, LineString or Polygon.", + geometry.type_tag() + ))); + } + let srid = srid_value.map(|v| extract_int(&v.try_into()?)).transpose()?; Ok(Filter::Geometry(GeometryFilter { @@ -749,27 +786,47 @@ fn parse_geometry_intersects( } fn extract_float(value: &PrismaValue) -> QueryGraphBuilderResult { - match value { - PrismaValue::Int(i) => Ok(*i as f64), - PrismaValue::BigInt(i) => Ok(*i as f64), + let result = match value { + PrismaValue::Int(i) => *i as f64, + PrismaValue::BigInt(i) => *i as f64, PrismaValue::Float(d) => d .to_string() .parse::() - .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e))), - _ => Err(QueryGraphBuilderError::InputError(format!( - "Expected numeric value, got {:?}", - value - ))), + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e)))?, + _ => { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected numeric value, got {:?}", + value + ))); + } + }; + + if !result.is_finite() { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected finite numeric value, got {}", + result + ))); } + + Ok(result) } fn extract_int(value: &PrismaValue) -> QueryGraphBuilderResult { - match value { - PrismaValue::Int(i) => Ok(*i as i32), - PrismaValue::BigInt(i) => Ok(*i as i32), - _ => Err(QueryGraphBuilderError::InputError(format!( - "Expected integer value, got {:?}", - value - ))), - } + let i64_value = match value { + PrismaValue::Int(i) => *i, + PrismaValue::BigInt(i) => *i, + _ => { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected integer value, got {:?}", + value + ))); + } + }; + + i32::try_from(i64_value).map_err(|_| { + QueryGraphBuilderError::InputError(format!( + "Integer value {} is out of range for a 32-bit signed integer", + i64_value + )) + }) } diff --git a/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs b/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs index b3108d33e6df..c0983eb4af05 100644 --- a/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs +++ b/query-compiler/core/src/query_graph_builder/extractors/query_arguments.rs @@ -138,12 +138,11 @@ fn process_order_object( } Field::Scalar(sf) => { - if matches!(sf.type_identifier(), TypeIdentifier::Geometry(_)) { - if let ParsedInputValue::Map(ref map) = field_value { - if let Some(distance_from_value) = map.get(ordering::DISTANCE_FROM) { - return extract_geometry_distance_from(&sf, distance_from_value.clone(), path); - } - } + if matches!(sf.type_identifier(), TypeIdentifier::Geometry(_)) + && let ParsedInputValue::Map(ref map) = field_value + && let Some(distance_from_value) = map.get(ordering::DISTANCE_FROM) + { + return extract_geometry_distance_from(&sf, distance_from_value.clone(), path); } let (sort_order, nulls_order) = extract_order_by_args(field_value)?; @@ -343,6 +342,21 @@ fn extract_compound_cursor_field( /// Runs final transformations on the QueryArguments. fn finalize_arguments(mut args: QueryArguments, model: &Model) -> QueryGraphBuilderResult { + // Cursor-based pagination needs a deterministic, totally ordered set of cursor columns. Geometry + // distance ordering is not deterministic (depends on a reference point that is not part of the + // row), so combining the two would silently produce undefined results. Reject upfront with a clear + // error rather than letting the SQL builder reach the geometry-cursor code path. + if args.cursor.is_some() + && args + .order_by + .iter() + .any(|order_by| matches!(order_by, OrderBy::Geometry(_))) + { + return Err(QueryGraphBuilderError::InputError( + "Cursor-based pagination is not supported when ordering by geometry distance.".to_owned(), + )); + } + // Check if the query requires an implicit ordering added to the arguments. // An implicit ordering is convenient for deterministic results for take and skip, for cursor it's _required_ // as a cursor needs a direction to page. We simply take the primary identifier as a default order-by. @@ -397,31 +411,57 @@ fn extract_geometry_distance_from( let sort_order = pv_to_sort_order(direction_value.try_into()?)?; let srid = srid_value.map(|v| extract_int_from_pv(&v.try_into()?)).transpose()?; - Ok(Some(OrderBy::geometry(field.clone(), path, (lon, lat), sort_order, srid))) + Ok(Some(OrderBy::geometry( + field.clone(), + path, + (lon, lat), + sort_order, + srid, + ))) } fn extract_float_from_pv(value: &PrismaValue) -> QueryGraphBuilderResult { - match value { - PrismaValue::Int(i) => Ok(*i as f64), - PrismaValue::BigInt(i) => Ok(*i as f64), + let result = match value { + PrismaValue::Int(i) => *i as f64, + PrismaValue::BigInt(i) => *i as f64, PrismaValue::Float(d) => d .to_string() .parse::() - .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e))), - _ => Err(QueryGraphBuilderError::InputError(format!( - "Expected numeric value, got {:?}", - value - ))), + .map_err(|e| QueryGraphBuilderError::InputError(format!("Invalid float value: {}", e)))?, + _ => { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected numeric value, got {:?}", + value + ))); + } + }; + + if !result.is_finite() { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected finite numeric value, got {}", + result + ))); } + + Ok(result) } fn extract_int_from_pv(value: &PrismaValue) -> QueryGraphBuilderResult { - match value { - PrismaValue::Int(i) => Ok(*i as i32), - PrismaValue::BigInt(i) => Ok(*i as i32), - _ => Err(QueryGraphBuilderError::InputError(format!( - "Expected integer value, got {:?}", - value - ))), - } + let i64_value = match value { + PrismaValue::Int(i) => *i, + PrismaValue::BigInt(i) => *i, + _ => { + return Err(QueryGraphBuilderError::InputError(format!( + "Expected integer value, got {:?}", + value + ))); + } + }; + + i32::try_from(i64_value).map_err(|_| { + QueryGraphBuilderError::InputError(format!( + "Integer value {} is out of range for a 32-bit signed integer", + i64_value + )) + }) } diff --git a/query-compiler/dmmf/Cargo.toml b/query-compiler/dmmf/Cargo.toml index fab7ed2e7cd2..ff828f64732f 100644 --- a/query-compiler/dmmf/Cargo.toml +++ b/query-compiler/dmmf/Cargo.toml @@ -6,7 +6,7 @@ edition.workspace = true [dependencies] bigdecimal.workspace = true itertools.workspace = true -psl = { workspace = true, features = ["all"] } +psl.workspace = true serde.workspace = true serde_json.workspace = true schema.workspace = true diff --git a/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs b/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs index 31b5011ae309..e53a5b727ab8 100644 --- a/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs +++ b/query-compiler/dmmf/src/ast_builders/datamodel_ast_builder.rs @@ -5,14 +5,12 @@ use crate::serialization_ast::{ use bigdecimal::ToPrimitive; use itertools::{Either, Itertools}; use psl::{ - parser_database::{GeometrySpec, ScalarFieldType, walkers}, + datamodel_connector::Connector, + parser_database::{ScalarFieldType, walkers}, schema_ast::ast::WithDocumentation, }; use query_structure::{DefaultKind, FieldArity, PrismaValue, dml_default_kind, encode_bytes}; -fn geometry_dmmf_field_type(spec: &GeometrySpec) -> String { - spec.postgres_sql_type() -} pub(crate) fn schema_to_dmmf(schema: &psl::ValidatedSchema) -> Datamodel { let mut datamodel = Datamodel { @@ -26,18 +24,19 @@ pub(crate) fn schema_to_dmmf(schema: &psl::ValidatedSchema) -> Datamodel { datamodel.enums.push(enum_to_dmmf(enum_model)); } + let connector = schema.connector; for model in schema .db .walk_models() .filter(|model| !model.is_ignored()) .chain(schema.db.walk_views().filter(|view| !view.is_ignored())) { - datamodel.models.push(model_to_dmmf(model)); + datamodel.models.push(model_to_dmmf(model, connector)); datamodel.indexes.extend(model_indexes_to_dmmf(model)); } for ct in schema.db.walk_composite_types() { - datamodel.types.push(composite_type_to_dmmf(ct)) + datamodel.types.push(composite_type_to_dmmf(ct, connector)) } datamodel @@ -66,7 +65,7 @@ fn enum_value_to_dmmf(en: walkers::EnumValueWalker<'_>) -> EnumValue { } } -fn composite_type_to_dmmf(ct: walkers::CompositeTypeWalker<'_>) -> Model { +fn composite_type_to_dmmf(ct: walkers::CompositeTypeWalker<'_>, connector: &'static dyn Connector) -> Model { Model { name: ct.name().to_owned(), db_name: None, @@ -74,7 +73,7 @@ fn composite_type_to_dmmf(ct: walkers::CompositeTypeWalker<'_>) -> Model { fields: ct .fields() .filter(|field| !matches!(field.r#type(), ScalarFieldType::Unsupported(_))) - .map(composite_type_field_to_dmmf) + .map(|f| composite_type_field_to_dmmf(f, connector)) .collect(), is_generated: None, documentation: ct.ast_composite_type().documentation().map(ToOwned::to_owned), @@ -84,13 +83,16 @@ fn composite_type_to_dmmf(ct: walkers::CompositeTypeWalker<'_>) -> Model { } } -fn composite_type_field_to_dmmf(field: walkers::CompositeTypeFieldWalker<'_>) -> Field { +fn composite_type_field_to_dmmf( + field: walkers::CompositeTypeFieldWalker<'_>, + _connector: &'static dyn Connector, +) -> Field { Field { name: field.name().to_owned(), kind: match field.r#type() { ScalarFieldType::CompositeType(_) => "object", ScalarFieldType::Enum(_) => "enum", - ScalarFieldType::BuiltInScalar(_) | ScalarFieldType::Geometry(_) => "scalar", + ScalarFieldType::BuiltInScalar(_) => "scalar", ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, db_name: field.mapped_name().map(ToOwned::to_owned), @@ -115,7 +117,6 @@ fn composite_type_field_to_dmmf(field: walkers::CompositeTypeFieldWalker<'_>) -> ScalarFieldType::CompositeType(ct) => field.walk(ct).name().to_owned(), ScalarFieldType::Enum(enm) => field.walk(enm).name().to_owned(), ScalarFieldType::BuiltInScalar(st) => st.as_str().to_owned(), - ScalarFieldType::Geometry(spec) => geometry_dmmf_field_type(&spec), ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, is_generated: None, @@ -124,7 +125,7 @@ fn composite_type_field_to_dmmf(field: walkers::CompositeTypeFieldWalker<'_>) -> } } -fn model_to_dmmf(model: walkers::ModelWalker<'_>) -> Model { +fn model_to_dmmf(model: walkers::ModelWalker<'_>, connector: &'static dyn Connector) -> Model { let primary_key = if let Some(pk) = model.primary_key() { (!pk.is_defined_on_field()).then(|| PrimaryKey { name: pk.name().map(ToOwned::to_owned), @@ -141,7 +142,7 @@ fn model_to_dmmf(model: walkers::ModelWalker<'_>) -> Model { fields: model .fields() .filter(|field| !should_skip_model_field(field)) - .map(field_to_dmmf) + .map(|f| field_to_dmmf(f, connector)) .collect(), is_generated: Some(false), documentation: model.ast_model().documentation().map(ToOwned::to_owned), @@ -169,14 +170,14 @@ fn should_skip_model_field(field: &walkers::FieldWalker<'_>) -> bool { } } -fn field_to_dmmf(field: walkers::FieldWalker<'_>) -> Field { +fn field_to_dmmf(field: walkers::FieldWalker<'_>, connector: &'static dyn Connector) -> Field { match field.refine_known() { - walkers::RefinedFieldWalker::Scalar(sf) => scalar_field_to_dmmf(sf), + walkers::RefinedFieldWalker::Scalar(sf) => scalar_field_to_dmmf(sf, connector), walkers::RefinedFieldWalker::Relation(rf) => relation_field_to_dmmf(rf), } } -fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>) -> Field { +fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>, _connector: &'static dyn Connector) -> Field { let ast_field = field.ast_field(); let field_walker = walkers::FieldWalker::from(field); let is_id = field.is_single_pk(); @@ -186,7 +187,7 @@ fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>) -> Field { kind: match field.scalar_field_type() { ScalarFieldType::CompositeType(_) => "object", ScalarFieldType::Enum(_) => "enum", - ScalarFieldType::BuiltInScalar(_) | ScalarFieldType::Geometry(_) => "scalar", + ScalarFieldType::BuiltInScalar(_) => "scalar", ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, is_list: ast_field.arity.is_list(), @@ -204,7 +205,6 @@ fn scalar_field_to_dmmf(field: walkers::ScalarFieldWalker<'_>) -> Field { ScalarFieldType::CompositeType(ct) => field_walker.walk(ct).name().to_owned(), ScalarFieldType::Enum(enm) => field_walker.walk(enm).name().to_owned(), ScalarFieldType::BuiltInScalar(st) => st.as_str().to_owned(), - ScalarFieldType::Geometry(spec) => geometry_dmmf_field_type(&spec), ScalarFieldType::Extension(_) | ScalarFieldType::Unsupported(_) => unreachable!(), }, native_type: field diff --git a/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs b/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs index 52736c6da94a..e74b6790961c 100644 --- a/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs +++ b/query-compiler/dmmf/src/ast_builders/schema_ast_builder/type_renderer.rs @@ -46,7 +46,7 @@ pub(super) fn render_output_type<'a>(output_type: &OutputType<'a>, ctx: &mut Ren ScalarType::UUID => "UUID".into(), ScalarType::JsonList => "Json".into(), ScalarType::Bytes => "Bytes".into(), - ScalarType::Geometry(s) => s.clone(), + ScalarType::Geometry(spec) => spec.psl_type_name().to_owned(), }; DmmfTypeReference { diff --git a/query-compiler/dmmf/src/tests/tests.rs b/query-compiler/dmmf/src/tests/tests.rs index 0449ea7b84b0..6e909f735c80 100644 --- a/query-compiler/dmmf/src/tests/tests.rs +++ b/query-compiler/dmmf/src/tests/tests.rs @@ -12,9 +12,10 @@ fn geometry_fields_in_datamodel_and_schema_dmmf() { } model Location { - id Int @id - position Geometry(Point, 4326) - path Geometry(LineString)? + id Int @id + position Geometry @db.Geometry(Point, 4326) + path Geometry? @db.Geometry(LineString) + footprint Geography @db.Geography(Polygon, 4326) } "#; @@ -25,10 +26,33 @@ fn geometry_fields_in_datamodel_and_schema_dmmf() { .iter() .find(|m| m.name == "Location") .expect("Location model"); + + // SevInf #8: DMMF `field_type` mirrors the PSL keyword (matches built-in scalar emission) + // and the structured native attribute arguments live in `native_type` instead of being + // baked into the type string. let pos = location.fields.iter().find(|f| f.name == "position").unwrap(); - assert_eq!(pos.field_type, "geometry(Point,4326)"); + assert_eq!(pos.kind, "scalar"); + assert_eq!(pos.field_type, "Geometry"); + assert_eq!( + pos.native_type, + Some(("Geometry".to_owned(), vec!["Point".to_owned(), "4326".to_owned()])) + ); + let path = location.fields.iter().find(|f| f.name == "path").unwrap(); - assert_eq!(path.field_type, "geometry(LineString,0)"); + assert_eq!(path.field_type, "Geometry"); + assert_eq!( + path.native_type, + Some(("Geometry".to_owned(), vec!["LineString".to_owned()])) + ); + + // SevInf #1/#4: `Geography` is now a first-class PSL type rendered separately from + // `Geometry`, with its own native attribute pairing. + let footprint = location.fields.iter().find(|f| f.name == "footprint").unwrap(); + assert_eq!(footprint.field_type, "Geography"); + assert_eq!( + footprint.native_type, + Some(("Geography".to_owned(), vec!["Polygon".to_owned(), "4326".to_owned()])) + ); let schema_json = serde_json::to_value(&dmmf.schema).unwrap(); let models = schema_json @@ -41,14 +65,20 @@ fn geometry_fields_in_datamodel_and_schema_dmmf() { .find(|m| m.get("name").and_then(|n| n.as_str()) == Some("Location")) .expect("Location output type"); let fields = location_out.get("fields").and_then(|f| f.as_array()).unwrap(); - let pos_field = fields + + let pos_out = fields .iter() - .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("position")); - let out_pos = pos_field.and_then(|f| f.get("outputType")).expect("position output"); - assert_eq!( - out_pos.get("type").and_then(|t| t.as_str()), - Some("geometry(Point,4326)") - ); + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("position")) + .and_then(|f| f.get("outputType")) + .expect("position output"); + assert_eq!(pos_out.get("type").and_then(|t| t.as_str()), Some("Geometry")); + + let footprint_out = fields + .iter() + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("footprint")) + .and_then(|f| f.get("outputType")) + .expect("footprint output"); + assert_eq!(footprint_out.get("type").and_then(|t| t.as_str()), Some("Geography")); } #[test] diff --git a/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs b/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs index 505450d763ad..43a48e5aeda8 100644 --- a/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs +++ b/query-compiler/query-builders/sql-query-builder/src/cursor_condition.rs @@ -442,7 +442,9 @@ fn order_definitions( OrderBy::ScalarAggregation(order_by) => cursor_order_def_aggregation_scalar(order_by, order_by_def), OrderBy::ToManyAggregation(order_by) => cursor_order_def_aggregation_rel(order_by, order_by_def), OrderBy::Relevance(order_by) => cursor_order_def_relevance(order_by, order_by_def), - OrderBy::Geometry(_) => unimplemented!("Cursor-based pagination with geometry orderBy is not yet supported"), + OrderBy::Geometry(_) => { + unimplemented!("Cursor-based pagination with geometry orderBy is not yet supported") + } }) .collect_vec() } diff --git a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs index 0efab02812c1..2ddfe5fc82e3 100644 --- a/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs +++ b/query-compiler/query-builders/sql-query-builder/src/filter/visitor.rs @@ -5,6 +5,7 @@ use crate::{Context, model_extensions::*}; use prisma_value::Placeholder as PrismaValuePlaceholder; use psl::datamodel_connector::ConnectorCapability; +use psl::parser_database::PostgisSpatialKind; use psl::reachable_only_with_capability; use quaint::ast::concat; use quaint::ast::*; @@ -619,123 +620,86 @@ impl FilterVisitorExt for FilterVisitor { } fn visit_geometry_filter(&mut self, filter: GeometryFilter, ctx: &Context<'_>) -> ConditionTree<'static> { - let field_column = filter.field.as_column(ctx); - let field_ref = format!("\"{}\"", field_column.name); - - let srid = match &filter.condition { + // Resolve the column with the parent alias so that filters nested inside relation + // sub-queries stay qualified, matching every other `visit_*_filter` path in this file. + let field_column: Column<'static> = filter.field.aliased_col(self.parent_alias(), ctx); + let field_expr: Expression<'static> = field_column.into(); + + // Determine PostGIS spatial kind (geometry vs geography) and the field's declared SRID + // directly from the schema instead of guessing from the SRID value. + let field_spec = filter.field.geometry_spec(); + let use_geography = field_spec + .map(|spec| matches!(spec.spatial, PostgisSpatialKind::Geography)) + .unwrap_or(false); + let field_srid = field_spec.and_then(|spec| spec.srid); + + // SRID chain: explicit override (filter arg) takes precedence over the field's + // declared SRID; fall back to 0 (PostGIS "unknown") only when neither is set. + let resolved_srid = match &filter.condition { GeometryFilterCondition::Near { srid, .. } | GeometryFilterCondition::Within { srid, .. } - | GeometryFilterCondition::Intersects { srid, .. } => srid.unwrap_or(4326), + | GeometryFilterCondition::Intersects { srid, .. } => srid.or(field_srid).unwrap_or(0), }; - let use_geography = srid == 4326 || srid == 4269 || srid == 4167; - - let sql = match filter.condition { + let condition_expr = match filter.condition { GeometryFilterCondition::Near { - point, - max_distance, - .. - } => { - let (lon, lat) = point; - if use_geography { - format!( - "ST_DWithin({}::geography, ST_SetSRID(ST_MakePoint({}, {}), {})::geography, {})", - field_ref, lon, lat, srid, max_distance - ) - } else { - format!( - "ST_DWithin({}, ST_SetSRID(ST_MakePoint({}, {}), {}), {})", - field_ref, lon, lat, srid, max_distance - ) - } - } + point, max_distance, .. + } => geometry_near_condition(field_expr, point, max_distance, resolved_srid, use_geography), GeometryFilterCondition::Within { polygon, .. } => { - let wkt = format_polygon_wkt(&polygon); - let escaped_wkt = wkt.replace('\'', "''"); - format!( - "ST_Within({}, ST_GeomFromText('{}', {}))", - field_ref, escaped_wkt, srid - ) + let wkt = format_polygon_ring_wkt(&polygon); + geometry_within_condition(field_expr, wkt, resolved_srid) } GeometryFilterCondition::Intersects { geometry, .. } => { - let geom_type = geometry.get("type").and_then(|v| v.as_str()).unwrap_or(""); - let wkt = match geom_type { - "Point" => { - if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { - if coords.len() >= 2 { - format!("POINT({} {})", - coords[0].as_f64().unwrap_or(0.0), - coords[1].as_f64().unwrap_or(0.0)) - } else { - panic!("Invalid Point coordinates: expected at least 2 values, got {}", coords.len()) - } - } else { - panic!("Invalid Point geometry: missing or invalid 'coordinates' array") - } - } - "LineString" => { - if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { - let point_strs: Vec = coords.iter() - .filter_map(|p| { - p.as_array().and_then(|arr| { - if arr.len() >= 2 { - Some(format!("{} {}", - arr[0].as_f64().unwrap_or(0.0), - arr[1].as_f64().unwrap_or(0.0))) - } else { - None - } - }) - }) - .collect(); - format!("LINESTRING({})", point_strs.join(", ")) - } else { - panic!("Invalid LineString geometry: missing or invalid 'coordinates' array") - } - } - "Polygon" => { - if let Some(coords) = geometry.get("coordinates").and_then(|v| v.as_array()) { - let ring_strs: Vec = coords.iter() - .filter_map(|ring| { - ring.as_array().map(|points| { - let point_strs: Vec = points.iter() - .filter_map(|p| { - p.as_array().and_then(|arr| { - if arr.len() >= 2 { - Some(format!("{} {}", - arr[0].as_f64().unwrap_or(0.0), - arr[1].as_f64().unwrap_or(0.0))) - } else { - None - } - }) - }) - .collect(); - format!("({})", point_strs.join(", ")) - }) - }) - .collect(); - format!("POLYGON({})", ring_strs.join(", ")) - } else { - "POLYGON EMPTY".to_string() - } - } - "" => panic!("Missing 'type' field in GeoJSON geometry"), - unsupported => panic!("Unsupported GeoJSON geometry type '{}' for intersects filter. Supported types: Point, LineString, Polygon", unsupported), - }; - let escaped_wkt = wkt.replace('\'', "''"); - format!( - "ST_Intersects({}, ST_GeomFromText('{}', {}))", - field_ref, escaped_wkt, srid - ) + // `parse_geometry_intersects` already rejected GeoJSON shapes that the SQL + // builder cannot lower to a single WKT (`Multi*` / `GeometryCollection`), so + // every value reaching this point is guaranteed to produce a `Some(wkt)`. + let wkt = geometry.to_wkt().unwrap_or_else(|| { + unreachable!( + "intersects filter received unsupported GeoJSON geometry `{}`; the extractor must reject it before reaching the SQL builder", + geometry.type_tag() + ) + }); + geometry_intersects_condition(field_expr, wkt, resolved_srid) } }; - let raw_expr: Expression = Value::enum_variant(sql).raw().into(); - ConditionTree::single(raw_expr) + ConditionTree::single(condition_expr) } } +fn geometry_near_condition( + field: Expression<'static>, + point: GeoCoord, + max_distance: f64, + srid: i32, + use_geography: bool, +) -> Expression<'static> { + let point_geom = st_set_srid(st_make_point(point.x, point.y), srid as i64); + let (lhs, rhs) = if use_geography { + let lhs: Expression<'static> = geography_cast(field).into(); + let rhs: Expression<'static> = geography_cast(point_geom).into(); + (lhs, rhs) + } else { + (field, point_geom.into()) + }; + st_dwithin(lhs, rhs, max_distance).into() +} + +fn geometry_within_condition(field: Expression<'static>, wkt: String, srid: i32) -> Expression<'static> { + let geom = st_geom_from_text(wkt, srid as i64); + st_within(field, geom).into() +} + +fn geometry_intersects_condition(field: Expression<'static>, wkt: String, srid: i32) -> Expression<'static> { + let geom = st_geom_from_text(wkt, srid as i64); + st_intersects(field, geom).into() +} + +fn format_polygon_ring_wkt(positions: &[GeoCoord]) -> String { + let parts: Vec<_> = positions.iter().map(|c| format!("{} {}", c.x, c.y)).collect(); + format!("POLYGON(({}))", parts.join(", ")) +} + fn scalar_filter_aliased_cond( sf: ScalarFilter, alias: Option, @@ -1663,12 +1627,3 @@ impl JsonFilterExt for (Expression<'static>, Expression<'static>) { } } } - -fn format_polygon_wkt(polygon: &[(f64, f64)]) -> String { - let coords = polygon - .iter() - .map(|(x, y)| format!("{} {}", x, y)) - .collect::>() - .join(", "); - format!("POLYGON(({}))", coords) -} diff --git a/query-compiler/query-builders/sql-query-builder/src/ordering.rs b/query-compiler/query-builders/sql-query-builder/src/ordering.rs index 1a22d39149ba..cda1a9bba74d 100644 --- a/query-compiler/query-builders/sql-query-builder/src/ordering.rs +++ b/query-compiler/query-builders/sql-query-builder/src/ordering.rs @@ -295,18 +295,22 @@ impl OrderByBuilder { .map(|j| j.alias.to_owned()) .or_else(|| self.parent_alias.clone()); let field_column = order_by.field.as_column(ctx).opt_table(parent_table); + let field_expr: Expression<'static> = field_column.into(); + // SRID chain: explicit override on the orderBy node > field's declared SRID > 4326. + // We default to 4326 (rather than 0) here because geography casting requires a known + // geographic CRS for ST_Distance to return meters. + let field_srid = order_by.field.geometry_spec().and_then(|s| s.srid); + let srid = order_by.srid.or(field_srid).unwrap_or(4326); let (lon, lat) = order_by.point; - let srid = order_by.srid.unwrap_or(4326); - let field_ref = format!("\"{}\"", field_column.name); - - let sql = format!( - "ST_Distance(CAST({} AS geography), CAST(ST_SetSRID(ST_MakePoint({}, {}), {}) AS geography))", - field_ref, lon, lat, srid - ); - - let distance_expr: Expression = Value::enum_variant(sql).raw().into(); + // Cast both sides to geography so ST_Distance is reported in meters regardless of the + // input CRS, matching the legacy behaviour but produced through the parameterised + // Function AST rather than string concatenation. + let point_geom = st_set_srid(st_make_point(lon, lat), srid as i64); + let lhs: Expression<'static> = geography_cast(field_expr).into(); + let rhs: Expression<'static> = geography_cast(point_geom).into(); + let distance_expr: Expression<'static> = st_distance(lhs, rhs).into(); let order = Some(into_order(&order_by.sort_order, None, needs_reversed_order)); let order_definition: OrderDefinition = (distance_expr.clone(), order); diff --git a/query-compiler/query-compiler/src/data_mapper.rs b/query-compiler/query-compiler/src/data_mapper.rs index caa588b205ec..eb7f1bd22c8e 100644 --- a/query-compiler/query-compiler/src/data_mapper.rs +++ b/query-compiler/query-compiler/src/data_mapper.rs @@ -5,7 +5,10 @@ use crate::{ use bon::builder; use indexmap::IndexSet; use itertools::Itertools; -use psl::datamodel_connector::Flavour; +use psl::{ + datamodel_connector::Flavour, + parser_database::{GeometrySpec, GeometrySubtype}, +}; use query_core::{ CreateManyRecordsFields, DeleteRecordFields, Node, Query, QueryGraph, ReadQuery, UpdateManyRecordsFields, UpdateRecord, WriteQuery, @@ -17,18 +20,15 @@ use query_structure::{ use serde::Serialize; use std::{borrow::Cow, collections::HashMap, fmt}; -/// Maps a DMMF geometry field type string (e.g. `geometry(Point,4326)`) to the JSON protocol discriminator -/// consumed by `@prisma/client-engine-runtime`. -fn geometry_json_geometry_type(dmmf_type: &str) -> String { - let inner = dmmf_type.strip_prefix("geometry(").and_then(|s| s.strip_suffix(')')); - let subtype = inner - .and_then(|s| s.split(',').next()) - .map(str::trim) - .unwrap_or("Geometry"); - match subtype { - "Point" => "point".to_owned(), - "LineString" => "linestring".to_owned(), - "Polygon" => "polygon".to_owned(), +/// Maps a `GeometrySpec` to the JSON protocol discriminator consumed by +/// `@prisma/client-engine-runtime`. The discriminator is derived structurally so changes to +/// the DMMF surface form (e.g. SevInf #8 emitting just `Geometry`/`Geography`) cannot affect +/// runtime decoding. +fn geometry_json_geometry_type(spec: &GeometrySpec) -> String { + match spec.subtype { + GeometrySubtype::Point => "point".to_owned(), + GeometrySubtype::LineString => "linestring".to_owned(), + GeometrySubtype::Polygon => "polygon".to_owned(), _ => "geometry".to_owned(), } } @@ -479,8 +479,8 @@ impl From<&Type> for FieldScalarType { TypeIdentifier::Bytes => Self::Bytes { encoding: ByteArrayEncoding::default(), }, - TypeIdentifier::Geometry(dmmf) => Self::Geometry { - geometry_type: geometry_json_geometry_type(dmmf), + TypeIdentifier::Geometry(spec) => Self::Geometry { + geometry_type: geometry_json_geometry_type(spec), }, TypeIdentifier::Unsupported => Self::Unsupported, } diff --git a/query-compiler/query-compiler/tests/data/schema.prisma b/query-compiler/query-compiler/tests/data/schema.prisma index 1f02dea8f358..f6c480c6f469 100644 --- a/query-compiler/query-compiler/tests/data/schema.prisma +++ b/query-compiler/query-compiler/tests/data/schema.prisma @@ -88,7 +88,7 @@ model DataTypes { model Location { id Int @id @default(autoincrement()) name String? - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } model Patient { diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap index eaad457ca5fe..3b45d5574d72 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-combined-scalar-spatial.json.snap @@ -10,7 +10,8 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."name", "t0"."position" FROM "public"."Location" - AS "t0" WHERE (ST_DWithin("position"::geography, - ST_SetSRID(ST_MakePoint(0, 0), 4326)::geography, 10000) AND - "t0"."name"::text LIKE $1)» -params [const(String("Paris%"))] + AS "t0" WHERE (ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, + $2), $3), $4) AND "t0"."name"::text LIKE $5)» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("10000"))), + const(String("Paris%"))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap index a8e73c6547b6..b53b3c65d873 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-count-with-filter.json.snap @@ -10,6 +10,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."name", "t0"."position" FROM "public"."Location" - AS "t0" WHERE ST_Within("position", ST_GeomFromText('POLYGON((-1 -1, -1 - 5, 5 5, 5 -1, -1 -1))', 4326)) ORDER BY "t0"."id" ASC LIMIT $1» -params [const(BigInt(10))] + AS "t0" WHERE ST_Within("t0"."position", ST_GeomFromText($1, $2)) ORDER + BY "t0"."id" ASC LIMIT $3» +params [const(String("POLYGON((-1 -1, -1 5, 5 5, 5 -1, -1 -1))")), + const(BigInt(4326)), const(BigInt(10))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap index 98d1d2c38e1d..73cfabff5ced 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-delete-with-filter.json.snap @@ -5,6 +5,8 @@ input_file: query-compiler/query-compiler/tests/data/geometry-delete-with-filter snapshot_kind: text --- dataMap affectedRows -execute «DELETE FROM "public"."Location" WHERE ST_DWithin("position"::geography, - ST_SetSRID(ST_MakePoint(0, 0), 4326)::geography, 1000)» -params [] +execute «DELETE FROM "public"."Location" WHERE + ST_DWithin("public"."Location"."position", ST_SetSRID(ST_MakePoint($1, + $2), $3), $4)» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("1000")))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap index f6098f4c10c2..cc44a14db397 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-orderby.json.snap @@ -9,7 +9,10 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), - 4326)::geography, 500000) ORDER BY ST_Distance(CAST("position" AS - geography), CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC» -params [] + ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), $4) + ORDER BY ST_Distance(geography("t0"."position"), + geography(ST_SetSRID(ST_MakePoint($5, $6), $7))) ASC» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("500000"))), + const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap index 23474d954577..438daec6eb6b 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-and-scalar-filter.json.snap @@ -9,6 +9,8 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - (ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), - 4326)::geography, 50000) AND "t0"."id" > $1)» -params [const(BigInt(100))] + (ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), $4) + AND "t0"."id" > $5)» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("50000"))), + const(BigInt(100))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap index 5bfba168fd12..afc1fa71c7e6 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-custom-srid.json.snap @@ -9,6 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - ST_DWithin("position", ST_SetSRID(ST_MakePoint(1000000, 6000000), 3857), - 5000)» -params [] + ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), $4)» +params [const(Float(BigDecimal("1000000"))), + const(Float(BigDecimal("6000000"))), const(BigInt(3857)), + const(Float(BigDecimal("5000")))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap index 1a745b0ff6e5..cc08c927258e 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-near.json.snap @@ -9,6 +9,6 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(2.35, 48.85), - 4326)::geography, 100000)» -params [] + ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), $4)» +params [const(Float(BigDecimal("2.35"))), const(Float(BigDecimal("48.85"))), + const(BigInt(4326)), const(Float(BigDecimal("100000")))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap index 7a87b611fb3b..ea66c47e559e 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-not-near.json.snap @@ -9,6 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - (NOT ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), - 4326)::geography, 10000))» -params [] + (NOT ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), + $4))» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("10000")))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap index b7100e40150e..3e4555396614 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-or-multiple.json.snap @@ -9,7 +9,9 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - (ST_DWithin("position"::geography, ST_SetSRID(ST_MakePoint(0, 0), - 4326)::geography, 10000) OR ST_DWithin("position"::geography, - ST_SetSRID(ST_MakePoint(10, 10), 4326)::geography, 5000))» -params [] + (ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($1, $2), $3), $4) OR + ST_DWithin("t0"."position", ST_SetSRID(ST_MakePoint($5, $6), $7), $8))» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(Float(BigDecimal("10000"))), + const(Float(BigDecimal("10"))), const(Float(BigDecimal("10"))), + const(BigInt(4326)), const(Float(BigDecimal("5000")))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap index 3b8f9993e9bc..6352df2b5491 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-filter-within.json.snap @@ -9,6 +9,6 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" WHERE - ST_Within("position", ST_GeomFromText('POLYGON((0 0, 0 2, 2 2, 2 0, 0 - 0))', 4326))» -params [] + ST_Within("t0"."position", ST_GeomFromText($1, $2))» +params [const(String("POLYGON((0 0, 0 2, 2 2, 2 0, 0 0))")), + const(BigInt(4326))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap index adf964059a5b..239f0789edda 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-multiple-orderby.json.snap @@ -9,7 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER - BY ST_Distance(CAST("position" AS geography), - CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC, "t0"."id" - DESC» -params [] + BY ST_Distance(geography("t0"."position"), + geography(ST_SetSRID(ST_MakePoint($1, $2), $3))) ASC, "t0"."id" DESC» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap index f4d717b1a431..6c2da4756a90 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-asc.json.snap @@ -9,6 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER - BY ST_Distance(CAST("position" AS geography), - CAST(ST_SetSRID(ST_MakePoint(2.35, 48.85), 4326) AS geography)) ASC» -params [] + BY ST_Distance(geography("t0"."position"), + geography(ST_SetSRID(ST_MakePoint($1, $2), $3))) ASC» +params [const(Float(BigDecimal("2.35"))), const(Float(BigDecimal("48.85"))), + const(BigInt(4326))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap index c5048c5579f9..cfc8babbdf1f 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-distance-desc.json.snap @@ -9,6 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER - BY ST_Distance(CAST("position" AS geography), - CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) DESC» -params [] + BY ST_Distance(geography("t0"."position"), + geography(ST_SetSRID(ST_MakePoint($1, $2), $3))) DESC» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326))] diff --git a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap index 77be8cca7955..5f6b5d43d9f8 100644 --- a/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap +++ b/query-compiler/query-compiler/tests/snapshots/queries__queries@geometry-orderby-with-limit.json.snap @@ -9,6 +9,7 @@ dataMap { position: Geometry(point)? (position) } query «SELECT "t0"."id", "t0"."position" FROM "public"."Location" AS "t0" ORDER - BY ST_Distance(CAST("position" AS geography), - CAST(ST_SetSRID(ST_MakePoint(0, 0), 4326) AS geography)) ASC LIMIT $1» -params [const(BigInt(5))] + BY ST_Distance(geography("t0"."position"), + geography(ST_SetSRID(ST_MakePoint($1, $2), $3))) ASC LIMIT $4» +params [const(Float(BigDecimal("0"))), const(Float(BigDecimal("0"))), + const(BigInt(4326)), const(BigInt(5))] diff --git a/query-compiler/query-structure/src/field/mod.rs b/query-compiler/query-structure/src/field/mod.rs index d3c7cc425ead..abdb96c15a25 100644 --- a/query-compiler/query-structure/src/field/mod.rs +++ b/query-compiler/query-structure/src/field/mod.rs @@ -9,7 +9,7 @@ pub use scalar::*; use crate::{Model, NativeTypeInstance, Zipper, parent_container::ParentContainer}; use psl::{ - parser_database::{EnumId, ExtensionTypeId, ScalarType, walkers}, + parser_database::{EnumId, ExtensionTypeId, GeometrySpec, GeometrySubtype, ScalarType, walkers}, schema_ast::ast::FieldArity, }; use std::{borrow::Cow, hash::Hash}; @@ -130,7 +130,7 @@ impl Field { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[allow(clippy::upper_case_acronyms)] pub enum TypeIdentifier { String, @@ -145,8 +145,10 @@ pub enum TypeIdentifier { Json, DateTime, Bytes, - /// String-encoded DMMF geometry type, e.g. `geometry(Point,4326)`. - Geometry(String), + /// PostGIS spatial type. The carried `GeometrySpec` records the OGC subtype, optional SRID + /// and whether the column is `geometry` or `geography`; the user-facing PSL type name + /// (`Geometry` or `Geography`) is derived via [`GeometrySpec::psl_type_name`]. + Geometry(GeometrySpec), Unsupported, } @@ -196,7 +198,7 @@ impl Type { TypeIdentifier::Json => "Json".into(), TypeIdentifier::DateTime => "DateTime".into(), TypeIdentifier::Bytes => "Bytes".into(), - TypeIdentifier::Geometry(s) => s.clone().into(), + TypeIdentifier::Geometry(spec) => spec.psl_type_name().into(), TypeIdentifier::Unsupported => "Unsupported".into(), } } @@ -291,6 +293,19 @@ impl From for TypeIdentifier { ScalarType::Json => Self::Json, ScalarType::Decimal => Self::Decimal, ScalarType::Bytes => Self::Bytes, + ScalarType::Geometry | ScalarType::Geography => { + // PostGIS scalars need the full `GeometrySpec` (subtype/SRID/spatial). Use + // `ScalarFieldRef::type_identifier()` directly so the native attribute is + // consulted; calling this `From` impl drops the spec by construction. + let spatial = st + .postgis_spatial_kind() + .expect("Geometry/Geography have a postgis_spatial_kind"); + Self::Geometry(GeometrySpec { + subtype: GeometrySubtype::Geometry, + srid: None, + spatial, + }) + } } } } diff --git a/query-compiler/query-structure/src/field/scalar.rs b/query-compiler/query-structure/src/field/scalar.rs index 254de601ba11..761c8b20afe1 100644 --- a/query-compiler/query-structure/src/field/scalar.rs +++ b/query-compiler/query-structure/src/field/scalar.rs @@ -2,15 +2,11 @@ use crate::{DefaultKind, NativeTypeInstance, ValueGenerator, ast, parent_contain use chrono::{DateTime, FixedOffset}; use psl::{ generators::{DEFAULT_CUID_VERSION, DEFAULT_UUID_VERSION}, - parser_database::{self as db, GeometrySpec, ScalarFieldType, ScalarType, walkers}, + parser_database::{self as db, GeometrySpec, GeometrySubtype, ScalarFieldType, ScalarType, walkers}, schema_ast::ast::FieldArity, }; use std::fmt::{Debug, Display}; -fn geometry_dmmf_string(spec: &GeometrySpec) -> String { - spec.postgres_sql_type() -} - pub type ScalarField = crate::Zipper; pub type ScalarFieldRef = ScalarField; @@ -101,12 +97,37 @@ impl ScalarField { } ScalarFieldType::Enum(x) => TypeIdentifier::Enum(x), ScalarFieldType::Extension(udt) => TypeIdentifier::Extension(udt), + ScalarFieldType::BuiltInScalar(scalar @ (ScalarType::Geometry | ScalarType::Geography)) => { + // Subtype/SRID live in the native attribute (`@db.Geometry(Point, 4326)` etc.). + // The bare `Geometry` / `Geography` keyword maps to the unconstrained default + // spec (`spec.spatial` = variant, no subtype/SRID), preserving the previous + // behaviour without needing a `GeometrySpec` payload on the PSL scalar. + let resolved = self.geometry_spec().unwrap_or_else(|| GeometrySpec { + subtype: GeometrySubtype::Geometry, + srid: None, + spatial: scalar + .postgis_spatial_kind() + .expect("matched only Geometry|Geography above"), + }); + TypeIdentifier::Geometry(resolved) + } ScalarFieldType::BuiltInScalar(scalar) => scalar.into(), - ScalarFieldType::Geometry(spec) => TypeIdentifier::Geometry(geometry_dmmf_string(&spec)), ScalarFieldType::Unsupported(_) => TypeIdentifier::Unsupported, } } + /// Returns the `GeometrySpec` of the field when its scalar type is `Geometry` / `Geography`. + /// + /// The native attribute (`@db.Geometry(...)` / `@db.Geography(...)`) is the single source of + /// truth for subtype/SRID/spatial kind. Bare `Geometry` / `Geography` keywords without an + /// explicit native attribute use the connector's default native type (see + /// `default_native_type_for_scalar_type`), so this method consistently returns a populated + /// spec whenever the field is a PostGIS scalar. + pub fn geometry_spec(&self) -> Option { + let nt = self.native_type()?; + nt.connector.geometry_spec_for_native_type(&nt.native_type) + } + pub fn arity(&self) -> FieldArity { match self.id { ScalarFieldId::InModel(id) => self.dm.walk(id).ast_field().arity, diff --git a/query-compiler/query-structure/src/filter/geojson.rs b/query-compiler/query-structure/src/filter/geojson.rs new file mode 100644 index 000000000000..1ec53c1e1c69 --- /dev/null +++ b/query-compiler/query-structure/src/filter/geojson.rs @@ -0,0 +1,437 @@ +//! Strongly-typed GeoJSON geometry representation used by the geometry filter pipeline. +//! +//! The geometry filter accepts user-supplied GeoJSON. Parsing it once at the extractor layer +//! into [`GeoJsonGeometry`] guarantees that downstream code (visitor, ordering, snapshots) can +//! rely on validated invariants: +//! +//! * `type` is one of the supported GeoJSON values. +//! * All coordinates are finite (`f64::is_finite`). +//! * Polygon rings are closed (first vertex == last vertex). +//! * Rings have at least four positions (i.e. the closing vertex is mandatory). +//! +//! Any violation is rejected with [`GeoJsonParseError`] so that callers can map it to a +//! user-facing input error instead of panicking later. + +use std::hash::{Hash, Hasher}; + +use thiserror::Error; + +/// A single GeoJSON 2D coordinate (longitude/easting + latitude/northing). +#[derive(Debug, Clone, Copy)] +pub struct GeoCoord { + pub x: f64, + pub y: f64, +} + +impl GeoCoord { + pub fn new(x: f64, y: f64) -> Result { + if !x.is_finite() || !y.is_finite() { + return Err(GeoJsonParseError::NonFiniteCoord { x, y }); + } + Ok(Self { x, y }) + } +} + +impl PartialEq for GeoCoord { + fn eq(&self, other: &Self) -> bool { + self.x.to_bits() == other.x.to_bits() && self.y.to_bits() == other.y.to_bits() + } +} + +impl Eq for GeoCoord {} + +impl Hash for GeoCoord { + fn hash(&self, state: &mut H) { + self.x.to_bits().hash(state); + self.y.to_bits().hash(state); + } +} + +/// Validated GeoJSON geometry subset. +/// +/// `Multi*` and `GeometryCollection` are accepted for round-tripping, but actual SQL support +/// in the visitor is currently limited to `Point`, `LineString`, and `Polygon`. The visitor +/// rejects unsupported variants explicitly rather than panicking. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum GeoJsonGeometry { + Point(GeoCoord), + LineString(Vec), + /// Polygon rings. The first ring is the outer boundary, subsequent rings are holes. + /// Every ring is guaranteed to be closed and contain ≥ 4 positions. + Polygon(Vec>), + MultiPoint(Vec), + MultiLineString(Vec>), + MultiPolygon(Vec>>), + GeometryCollection(Vec), +} + +impl GeoJsonGeometry { + /// Returns the GeoJSON `type` discriminator (matches the spec spelling). + pub fn type_tag(&self) -> &'static str { + match self { + Self::Point(_) => "Point", + Self::LineString(_) => "LineString", + Self::Polygon(_) => "Polygon", + Self::MultiPoint(_) => "MultiPoint", + Self::MultiLineString(_) => "MultiLineString", + Self::MultiPolygon(_) => "MultiPolygon", + Self::GeometryCollection(_) => "GeometryCollection", + } + } + + /// Serialise the geometry as Well-Known Text (WKT) suitable for `ST_GeomFromText`. + /// + /// Returns `None` for variants that the visitor currently does not generate WKT for + /// (`MultiPoint`, `MultiLineString`, `MultiPolygon`, `GeometryCollection`). Those variants + /// are still accepted by the parser to allow lossless round-tripping but the visitor + /// rejects them with a clear error rather than producing partially correct SQL. + pub fn to_wkt(&self) -> Option { + fn coord(c: &GeoCoord) -> String { + // Use Rust default formatting which preserves enough precision for f64 round-tripping + // and never emits scientific notation that PostGIS would mis-parse. + format!("{} {}", c.x, c.y) + } + fn ring(positions: &[GeoCoord]) -> String { + let parts: Vec<_> = positions.iter().map(coord).collect(); + format!("({})", parts.join(", ")) + } + + match self { + Self::Point(p) => Some(format!("POINT({})", coord(p))), + Self::LineString(positions) => { + if positions.is_empty() { + return Some("LINESTRING EMPTY".to_owned()); + } + let parts: Vec<_> = positions.iter().map(coord).collect(); + Some(format!("LINESTRING({})", parts.join(", "))) + } + Self::Polygon(rings) => { + if rings.is_empty() { + return Some("POLYGON EMPTY".to_owned()); + } + let parts: Vec<_> = rings.iter().map(|r| ring(r)).collect(); + Some(format!("POLYGON({})", parts.join(", "))) + } + Self::MultiPoint(_) | Self::MultiLineString(_) | Self::MultiPolygon(_) | Self::GeometryCollection(_) => { + None + } + } + } + + /// Parses a `serde_json::Value` produced from user input into a validated geometry. + /// + /// Validation covers: + /// * The `type` field must be present and a recognised GeoJSON type. + /// * The `coordinates` (or `geometries`) field must be present and shape-correct. + /// * All numeric coordinates must be finite. + /// * Polygon rings are auto-closed when the caller provides at least three distinct + /// positions; missing first/last vertex equality is repaired silently to match + /// PostGIS leniency. Rings shorter than three distinct positions are rejected. + pub fn from_serde_value(value: &serde_json::Value) -> Result { + let obj = value.as_object().ok_or(GeoJsonParseError::ExpectedObject)?; + + let type_str = obj + .get("type") + .and_then(|t| t.as_str()) + .ok_or(GeoJsonParseError::MissingType)?; + + match type_str { + "Point" => { + let coords = expect_coords(obj)?; + Self::parse_position(coords).map(Self::Point) + } + "LineString" => { + let coords = expect_coords(obj)?; + Self::parse_position_array(coords).map(Self::LineString) + } + "Polygon" => { + let coords = expect_coords(obj)?; + Self::parse_polygon_rings(coords).map(Self::Polygon) + } + "MultiPoint" => { + let coords = expect_coords(obj)?; + Self::parse_position_array(coords).map(Self::MultiPoint) + } + "MultiLineString" => { + let coords = expect_coords(obj)?; + let lines = coords + .as_array() + .ok_or(GeoJsonParseError::InvalidShape { + type_tag: "MultiLineString", + reason: "expected `coordinates` to be an array of LineString arrays", + })? + .iter() + .map(Self::parse_position_array) + .collect::, _>>()?; + Ok(Self::MultiLineString(lines)) + } + "MultiPolygon" => { + let coords = expect_coords(obj)?; + let polygons = coords + .as_array() + .ok_or(GeoJsonParseError::InvalidShape { + type_tag: "MultiPolygon", + reason: "expected `coordinates` to be an array of Polygon ring arrays", + })? + .iter() + .map(Self::parse_polygon_rings) + .collect::, _>>()?; + Ok(Self::MultiPolygon(polygons)) + } + "GeometryCollection" => { + let geometries = + obj.get("geometries") + .and_then(|g| g.as_array()) + .ok_or(GeoJsonParseError::InvalidShape { + type_tag: "GeometryCollection", + reason: "expected a `geometries` array", + })?; + let inner = geometries + .iter() + .map(Self::from_serde_value) + .collect::, _>>()?; + Ok(Self::GeometryCollection(inner)) + } + other => Err(GeoJsonParseError::UnsupportedType(other.to_owned())), + } + } + + fn parse_position(value: &serde_json::Value) -> Result { + let arr = value.as_array().ok_or(GeoJsonParseError::InvalidPosition)?; + if arr.len() < 2 { + return Err(GeoJsonParseError::InvalidPosition); + } + let x = arr[0].as_f64().ok_or(GeoJsonParseError::InvalidPosition)?; + let y = arr[1].as_f64().ok_or(GeoJsonParseError::InvalidPosition)?; + GeoCoord::new(x, y) + } + + fn parse_position_array(value: &serde_json::Value) -> Result, GeoJsonParseError> { + value + .as_array() + .ok_or(GeoJsonParseError::InvalidPositionArray)? + .iter() + .map(Self::parse_position) + .collect() + } + + fn parse_polygon_rings(value: &serde_json::Value) -> Result>, GeoJsonParseError> { + let rings = value.as_array().ok_or(GeoJsonParseError::InvalidShape { + type_tag: "Polygon", + reason: "expected `coordinates` to be an array of rings", + })?; + + let mut parsed_rings = Vec::with_capacity(rings.len()); + for ring in rings { + let mut positions = Self::parse_position_array(ring)?; + close_ring(&mut positions)?; + parsed_rings.push(positions); + } + Ok(parsed_rings) + } +} + +fn expect_coords(obj: &serde_json::Map) -> Result<&serde_json::Value, GeoJsonParseError> { + obj.get("coordinates").ok_or(GeoJsonParseError::MissingCoordinates) +} + +fn close_ring(positions: &mut Vec) -> Result<(), GeoJsonParseError> { + if positions.len() < 3 { + return Err(GeoJsonParseError::RingTooShort { len: positions.len() }); + } + let first = *positions.first().unwrap(); + let last = *positions.last().unwrap(); + if first != last { + positions.push(first); + } + if positions.len() < 4 { + return Err(GeoJsonParseError::RingTooShort { len: positions.len() }); + } + Ok(()) +} + +/// Errors that can be produced while validating GeoJSON input. All variants are user-facing +/// and translated into `InputError` by the extractor. +#[derive(Debug, Error)] +pub enum GeoJsonParseError { + #[error("GeoJSON geometry must be a JSON object")] + ExpectedObject, + #[error("GeoJSON geometry is missing the required `type` field")] + MissingType, + #[error("GeoJSON geometry is missing the required `coordinates` field")] + MissingCoordinates, + #[error("GeoJSON `{type_tag}` geometry is malformed: {reason}")] + InvalidShape { + type_tag: &'static str, + reason: &'static str, + }, + #[error("GeoJSON position must be an array of at least two finite numbers")] + InvalidPosition, + #[error("GeoJSON expected an array of positions")] + InvalidPositionArray, + #[error("GeoJSON polygon ring must contain at least 3 distinct positions (got {len})")] + RingTooShort { len: usize }, + #[error("GeoJSON coordinates must be finite numbers (got [{x}, {y}])")] + NonFiniteCoord { x: f64, y: f64 }, + #[error("Unsupported GeoJSON geometry type `{0}`")] + UnsupportedType(String), +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[track_caller] + fn parse_err(value: serde_json::Value) -> GeoJsonParseError { + GeoJsonGeometry::from_serde_value(&value).expect_err("expected parse error") + } + + #[test] + fn parses_point() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "Point", + "coordinates": [1.0, 2.0] + })) + .unwrap(); + assert!(matches!(geom, GeoJsonGeometry::Point(GeoCoord { x, y }) if x == 1.0 && y == 2.0)); + } + + #[test] + fn rejects_non_object_input() { + assert!(matches!(parse_err(json!([1, 2])), GeoJsonParseError::ExpectedObject)); + } + + #[test] + fn rejects_missing_type() { + assert!(matches!( + parse_err(json!({ "coordinates": [1.0, 2.0] })), + GeoJsonParseError::MissingType + )); + } + + #[test] + fn rejects_unknown_type() { + assert!(matches!( + parse_err(json!({ "type": "Triangle", "coordinates": [1.0, 2.0] })), + GeoJsonParseError::UnsupportedType(ref t) if t == "Triangle" + )); + } + + #[test] + fn rejects_missing_coordinates() { + assert!(matches!( + parse_err(json!({ "type": "Point" })), + GeoJsonParseError::MissingCoordinates + )); + } + + #[test] + fn rejects_short_position() { + assert!(matches!( + parse_err(json!({ "type": "Point", "coordinates": [1.0] })), + GeoJsonParseError::InvalidPosition + )); + } + + #[test] + fn rejects_non_finite_coordinate() { + // Non-finite floats cannot survive a JSON round-trip, but the parser still validates + // them defensively because callers can synthesise `GeoCoord` instances directly. Exercise + // the guard through `GeoCoord::new` to keep the invariant tested. + assert!(matches!( + GeoCoord::new(f64::NAN, 0.0), + Err(GeoJsonParseError::NonFiniteCoord { .. }) + )); + assert!(matches!( + GeoCoord::new(0.0, f64::INFINITY), + Err(GeoJsonParseError::NonFiniteCoord { .. }) + )); + assert!(matches!( + GeoCoord::new(0.0, f64::NEG_INFINITY), + Err(GeoJsonParseError::NonFiniteCoord { .. }) + )); + } + + #[test] + fn auto_closes_open_polygon_ring() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "Polygon", + "coordinates": [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]] + })) + .unwrap(); + let GeoJsonGeometry::Polygon(rings) = geom else { + panic!("expected polygon"); + }; + let ring = &rings[0]; + assert_eq!(ring.len(), 4, "open ring should be auto-closed"); + assert_eq!(ring.first().unwrap(), ring.last().unwrap()); + } + + #[test] + fn accepts_already_closed_polygon_ring() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "Polygon", + "coordinates": [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]] + })) + .unwrap(); + let GeoJsonGeometry::Polygon(rings) = geom else { + panic!("expected polygon"); + }; + assert_eq!(rings[0].len(), 4); + } + + #[test] + fn rejects_polygon_ring_too_short() { + assert!(matches!( + parse_err(json!({ + "type": "Polygon", + "coordinates": [[[0.0, 0.0], [1.0, 1.0]]] + })), + GeoJsonParseError::RingTooShort { len: 2 } + )); + } + + #[test] + fn point_wkt_round_trips() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "Point", + "coordinates": [2.35, 48.85] + })) + .unwrap(); + assert_eq!(geom.to_wkt().as_deref(), Some("POINT(2.35 48.85)")); + } + + #[test] + fn polygon_wkt_uses_closed_ring() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "Polygon", + "coordinates": [[[0.0, 0.0], [0.0, 2.0], [2.0, 2.0], [2.0, 0.0]]] + })) + .unwrap(); + assert_eq!(geom.to_wkt().as_deref(), Some("POLYGON((0 0, 0 2, 2 2, 2 0, 0 0))")); + } + + #[test] + fn multi_variants_have_no_wkt() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "MultiPoint", + "coordinates": [[0.0, 0.0], [1.0, 1.0]] + })) + .unwrap(); + assert!(geom.to_wkt().is_none()); + } + + #[test] + fn geometry_collection_round_trips() { + let geom = GeoJsonGeometry::from_serde_value(&json!({ + "type": "GeometryCollection", + "geometries": [ + { "type": "Point", "coordinates": [0.0, 0.0] }, + { "type": "Point", "coordinates": [1.0, 1.0] } + ] + })) + .unwrap(); + assert!(matches!(geom, GeoJsonGeometry::GeometryCollection(ref inner) if inner.len() == 2)); + } +} diff --git a/query-compiler/query-structure/src/filter/geometry.rs b/query-compiler/query-structure/src/filter/geometry.rs index 0ce7007bf21a..d9a1c4764bc9 100644 --- a/query-compiler/query-structure/src/filter/geometry.rs +++ b/query-compiler/query-structure/src/filter/geometry.rs @@ -15,16 +15,18 @@ impl PartialEq for GeometryFilter { #[derive(Debug, Clone)] pub enum GeometryFilterCondition { Near { - point: (f64, f64), + point: GeoCoord, max_distance: f64, srid: Option, }, Within { - polygon: Vec<(f64, f64)>, + /// Polygon ring (closed: first vertex == last vertex, length ≥ 4) validated at + /// extraction time. + polygon: Vec, srid: Option, }, Intersects { - geometry: serde_json::Value, + geometry: GeoJsonGeometry, srid: Option, }, } @@ -43,22 +45,17 @@ impl PartialEq for GeometryFilterCondition { max_distance: d2, srid: s2, }, - ) => { - p1.0.to_bits() == p2.0.to_bits() - && p1.1.to_bits() == p2.1.to_bits() - && d1.to_bits() == d2.to_bits() - && s1 == s2 - } + ) => p1 == p2 && d1.to_bits() == d2.to_bits() && s1 == s2, ( - GeometryFilterCondition::Within { polygon: poly1, srid: s1 }, - GeometryFilterCondition::Within { polygon: poly2, srid: s2 }, - ) => { - s1 == s2 - && poly1.len() == poly2.len() - && poly1.iter().zip(poly2.iter()).all(|((x1, y1), (x2, y2))| { - x1.to_bits() == x2.to_bits() && y1.to_bits() == y2.to_bits() - }) - } + GeometryFilterCondition::Within { + polygon: poly1, + srid: s1, + }, + GeometryFilterCondition::Within { + polygon: poly2, + srid: s2, + }, + ) => s1 == s2 && poly1 == poly2, ( GeometryFilterCondition::Intersects { geometry: g1, srid: s1 }, GeometryFilterCondition::Intersects { geometry: g2, srid: s2 }, @@ -72,24 +69,24 @@ impl std::hash::Hash for GeometryFilter { fn hash(&self, state: &mut H) { self.field.hash(state); match &self.condition { - GeometryFilterCondition::Near { point, max_distance, srid } => { + GeometryFilterCondition::Near { + point, + max_distance, + srid, + } => { "Near".hash(state); - point.0.to_bits().hash(state); - point.1.to_bits().hash(state); + point.hash(state); max_distance.to_bits().hash(state); srid.hash(state); } GeometryFilterCondition::Within { polygon, srid } => { "Within".hash(state); - for (x, y) in polygon { - x.to_bits().hash(state); - y.to_bits().hash(state); - } + polygon.hash(state); srid.hash(state); } GeometryFilterCondition::Intersects { geometry, srid } => { "Intersects".hash(state); - geometry.to_string().hash(state); + geometry.hash(state); srid.hash(state); } } diff --git a/query-compiler/query-structure/src/filter/mod.rs b/query-compiler/query-structure/src/filter/mod.rs index 187a921cf69b..7929f4ad0339 100644 --- a/query-compiler/query-structure/src/filter/mod.rs +++ b/query-compiler/query-structure/src/filter/mod.rs @@ -7,6 +7,7 @@ mod compare; mod composite; +mod geojson; mod geometry; mod into_filter; mod json; @@ -17,6 +18,7 @@ mod scalar; pub use compare::*; pub use composite::*; +pub use geojson::*; pub use geometry::*; pub use into_filter::*; pub use json::*; diff --git a/query-compiler/query-structure/src/prisma_value_ext.rs b/query-compiler/query-structure/src/prisma_value_ext.rs index 946f31e393b9..6226e78ee18a 100644 --- a/query-compiler/query-structure/src/prisma_value_ext.rs +++ b/query-compiler/query-structure/src/prisma_value_ext.rs @@ -10,7 +10,7 @@ pub(crate) trait PrismaValueExtensions { impl PrismaValueExtensions for PrismaValue { // Todo this is not exhaustive for now. fn coerce(self, to_type: &Type) -> crate::Result { - let coerced = match (self, to_type.id.clone()) { + let coerced = match (self, to_type.id) { // Trivial cases (PrismaValue::Null, _) => PrismaValue::Null, (val @ PrismaValue::String(_), TypeIdentifier::String) => val, diff --git a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs index 48d099ec4a39..da7e9072d2af 100644 --- a/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs +++ b/query-compiler/schema/src/build/input_types/fields/field_filter_types.rs @@ -226,7 +226,7 @@ fn full_scalar_filter_type( let scalar_type_name = ctx .internal_data_model .clone() - .zip(typ.clone()) + .zip(typ) .type_name() .into_owned(); let type_name = ctx.connector.scalar_filter_name(scalar_type_name, native_type_name); @@ -241,7 +241,7 @@ fn full_scalar_filter_type( let mut object = init_input_object_type(ident); object.set_fields(move || { - let mapped_scalar_type = map_scalar_input_type(ctx, typ.clone(), list); + let mapped_scalar_type = map_scalar_input_type(ctx, typ, list); let mut fields: Vec<_> = match &typ { TypeIdentifier::String | TypeIdentifier::UUID => equality_filters(mapped_scalar_type.clone(), nullable) .chain(inclusion_filters(ctx, mapped_scalar_type.clone(), nullable)) @@ -295,7 +295,7 @@ fn full_scalar_filter_type( fields.push(not_filter_field( ctx, - typ.clone(), + typ, native_type.clone(), mapped_scalar_type, nullable, @@ -313,7 +313,7 @@ fn full_scalar_filter_type( )); if typ.is_numeric() { - let avg_type = map_avg_type_ident(typ.clone()); + let avg_type = map_avg_type_ident(typ); fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_AVG, @@ -325,7 +325,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_SUM, - typ.clone(), + typ, nullable, list, )); @@ -335,7 +335,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_MIN, - typ.clone(), + typ, nullable, list, )); @@ -343,7 +343,7 @@ fn full_scalar_filter_type( fields.push(aggregate_filter_field( ctx, aggregations::UNDERSCORE_MAX, - typ.clone(), + typ, nullable, list, )); @@ -675,7 +675,12 @@ fn geometry_filters<'a>() -> impl Iterator> { vec![ simple_input_field(filters::NEAR, InputType::object(geometry_near_input()), None).optional(), simple_input_field(filters::WITHIN, InputType::object(geometry_within_input()), None).optional(), - simple_input_field(filters::INTERSECTS, InputType::object(geometry_intersects_input()), None).optional(), + simple_input_field( + filters::INTERSECTS, + InputType::object(geometry_intersects_input()), + None, + ) + .optional(), ] .into_iter() } diff --git a/query-compiler/schema/src/build/input_types/mod.rs b/query-compiler/schema/src/build/input_types/mod.rs index d3df06e9c82e..63ffc39a113d 100644 --- a/query-compiler/schema/src/build/input_types/mod.rs +++ b/query-compiler/schema/src/build/input_types/mod.rs @@ -23,7 +23,7 @@ fn map_scalar_input_type(ctx: &'_ QuerySchema, typ: TypeIdentifier, list: bool) TypeIdentifier::Extension(_) => unreachable!("No extension field should reach this path"), TypeIdentifier::Bytes => InputType::bytes(), TypeIdentifier::BigInt => InputType::bigint(), - TypeIdentifier::Geometry(s) => InputType::Scalar(ScalarType::Geometry(s)), + TypeIdentifier::Geometry(spec) => InputType::Scalar(ScalarType::Geometry(spec)), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }; diff --git a/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs b/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs index 287d8c5109e7..c11bffab020d 100644 --- a/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs +++ b/query-compiler/schema/src/build/input_types/objects/order_by_objects.rs @@ -298,12 +298,14 @@ fn geometry_order_by_object_type<'a>() -> InputObjectType<'a> { let mut object = init_input_object_type(ident); object.set_fields(|| { - vec![simple_input_field( - ordering::DISTANCE_FROM, - InputType::object(geometry_distance_from_input()), - None, - ) - .optional()] + vec![ + simple_input_field( + ordering::DISTANCE_FROM, + InputType::object(geometry_distance_from_input()), + None, + ) + .optional(), + ] }); object diff --git a/query-compiler/schema/src/build/output_types/field.rs b/query-compiler/schema/src/build/output_types/field.rs index 50ef922e4b21..7f21c7d435b9 100644 --- a/query-compiler/schema/src/build/output_types/field.rs +++ b/query-compiler/schema/src/build/output_types/field.rs @@ -40,7 +40,7 @@ pub(crate) fn map_scalar_output_type<'a>(ctx: &'a QuerySchema, typ: &TypeIdentif TypeIdentifier::Int => OutputType::int(), TypeIdentifier::Bytes => OutputType::bytes(), TypeIdentifier::BigInt => OutputType::bigint(), - TypeIdentifier::Geometry(s) => OutputType::geometry(s.clone()), + TypeIdentifier::Geometry(spec) => OutputType::geometry(*spec), TypeIdentifier::Unsupported => unreachable!("No unsupported field should reach this path"), }; diff --git a/query-compiler/schema/src/output_types.rs b/query-compiler/schema/src/output_types.rs index a27ee5db56ad..eb3d6788c5b6 100644 --- a/query-compiler/schema/src/output_types.rs +++ b/query-compiler/schema/src/output_types.rs @@ -76,8 +76,8 @@ impl<'a> OutputType<'a> { InnerOutputType::Scalar(ScalarType::Bytes) } - pub(crate) fn geometry(dmmf_type: String) -> InnerOutputType<'a> { - InnerOutputType::Scalar(ScalarType::Geometry(dmmf_type)) + pub(crate) fn geometry(spec: db::GeometrySpec) -> InnerOutputType<'a> { + InnerOutputType::Scalar(ScalarType::Geometry(spec)) } /// Attempts to recurse through the type until an object type is found. diff --git a/query-compiler/schema/src/query_schema.rs b/query-compiler/schema/src/query_schema.rs index 029e0065044c..37cbd78f2801 100644 --- a/query-compiler/schema/src/query_schema.rs +++ b/query-compiler/schema/src/query_schema.rs @@ -373,8 +373,10 @@ pub enum ScalarType { JsonList, UUID, Bytes, - /// DMMF string form, e.g. `geometry(Point,4326)`. - Geometry(String), + /// PostGIS spatial type. The `GeometrySpec` records the OGC subtype, SRID and whether the + /// column is `geometry` or `geography`. The DMMF surface form is the PSL keyword + /// (`Geometry` or `Geography`); structured arguments are reported via `native_type`. + Geometry(db::GeometrySpec), } impl fmt::Display for ScalarType { @@ -392,7 +394,7 @@ impl fmt::Display for ScalarType { ScalarType::UUID => "UUID", ScalarType::JsonList => "Json", ScalarType::Bytes => "Bytes", - ScalarType::Geometry(s) => s.as_str(), + ScalarType::Geometry(spec) => spec.psl_type_name(), }; f.write_str(typ) diff --git a/query-engine/connectors/mongodb-query-connector/src/filter.rs b/query-engine/connectors/mongodb-query-connector/src/filter.rs index 0039ca148517..cd4b8a168057 100644 --- a/query-engine/connectors/mongodb-query-connector/src/filter.rs +++ b/query-engine/connectors/mongodb-query-connector/src/filter.rs @@ -94,6 +94,11 @@ impl MongoFilterVisitor { Filter::Aggregation(filter) => self.visit_aggregation_filter(filter)?, Filter::Composite(filter) => self.visit_composite_filter(filter)?, Filter::BoolFilter(_) => unimplemented!("MongoDB boolean filter."), + Filter::Geometry(_) => { + return Err(MongoError::Unsupported( + "Geometry filters are not supported on MongoDB".to_string(), + )); + } }; Ok(filter_pair) diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs index f1c62e5feecd..71bd931c790e 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/renderer.rs @@ -585,6 +585,13 @@ fn render_column_type_postgres(col: TableColumnWalker<'_>) -> Cow<'static, str> } return format!("{}({})", name, args.iter().format(", ")).into(); } + PostgresType::Postgis(postgis) => { + // Emit `geometry(Subtype, SRID)` or `geography(Subtype, SRID)` directly. The lower- + // case spelling matches what PostGIS uses in its catalogs so the migration DDL + // round-trips cleanly through introspection. + let spec = postgis.to_geometry_spec(); + return spec.postgres_sql_type().into(); + } }; let tpe: Cow<'_, str> = match native_type { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs index f5eed3232cde..49726c838ba8 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/schema_differ.rs @@ -383,17 +383,16 @@ fn postgres_column_type_change(columns: MigrationPair>) -> let from_list_to_scalar = columns.previous.arity().is_list() && !columns.next.arity().is_list(); let from_scalar_to_list = !columns.previous.arity().is_list() && columns.next.arity().is_list(); - match (columns.previous.column_type_family(), columns.next.column_type_family()) { - (ColumnTypeFamily::Geometry(prev), ColumnTypeFamily::Geometry(next)) => { - if from_list_to_scalar || from_scalar_to_list { - return Some(NotCastable); - } - if prev == next { - return None; - } - return Some(RiskyCast); + if let (ColumnTypeFamily::Geometry(prev), ColumnTypeFamily::Geometry(next)) = + (columns.previous.column_type_family(), columns.next.column_type_family()) + { + if from_list_to_scalar || from_scalar_to_list { + return Some(NotCastable); + } + if prev == next { + return None; } - _ => {} + return Some(RiskyCast); } match (previous_type, next_type) { diff --git a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs index 8d1a5aab88f0..008c1799117c 100644 --- a/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs +++ b/schema-engine/connectors/sql-schema-connector/src/introspection/introspection_pair/scalar_field.rs @@ -2,7 +2,7 @@ use crate::introspection::sanitize_datamodel_names; use either::Either; use psl::{ datamodel_connector::{Flavour, walker_ext_traits::IndexWalkerExt}, - parser_database::{ExtensionTypeEntry, ScalarFieldType, walkers}, + parser_database::{ExtensionTypeEntry, PostgisSpatialKind, ScalarFieldType, walkers}, schema_ast::ast::WithDocumentation, }; use sql::ColumnArity; @@ -108,14 +108,16 @@ impl<'a> ScalarFieldPair<'a> { sql::ColumnTypeFamily::Json => Cow::from("Json"), sql::ColumnTypeFamily::Uuid => Cow::from("String"), sql::ColumnTypeFamily::Geometry(spec) => { - use std::fmt::Write; - let mut out = String::from("Geometry("); - out.push_str(spec.subtype.as_str()); - if let Some(srid) = spec.srid { - write!(&mut out, ", {srid}").unwrap(); + // PSL no longer uses the inline `Geometry(Subtype, SRID)` form. Subtype, SRID + // and the planar/geodetic kind are expressed via the @db.Geometry / @db.Geography + // native attributes that the renderer attaches from `native_type()`. The PSL + // keyword still has to match the spatial kind, otherwise the validator emits a + // "Native type Geography is not compatible with declared field type Geometry" + // error and the round-trip breaks for `geography` columns. + match spec.spatial { + PostgisSpatialKind::Geometry => Cow::Borrowed("Geometry"), + PostgisSpatialKind::Geography => Cow::Borrowed("Geography"), } - out.push(')'); - Cow::Owned(out) } sql::ColumnTypeFamily::Enum(id) => self.context.enum_prisma_name(*id).prisma_name(), &sql::ColumnTypeFamily::Udt(id) => self @@ -163,7 +165,17 @@ impl<'a> ScalarFieldPair<'a> { sql::ColumnTypeFamily::Json => psl::parser_database::ScalarType::Json, sql::ColumnTypeFamily::Uuid => psl::parser_database::ScalarType::String, sql::ColumnTypeFamily::Binary => psl::parser_database::ScalarType::Bytes, - sql::ColumnTypeFamily::Geometry(spec) => return Some(ScalarFieldType::Geometry(*spec)), + sql::ColumnTypeFamily::Geometry(spec) => match spec.spatial { + // Map the catalog-level spatial kind back to the matching PSL keyword. Subtype + // and SRID surface as the `@db.Geometry(...)` / `@db.Geography(...)` native + // attribute, just like every other parametrized scalar. + PostgisSpatialKind::Geometry => { + return Some(ScalarFieldType::BuiltInScalar(psl::parser_database::ScalarType::Geometry)); + } + PostgisSpatialKind::Geography => { + return Some(ScalarFieldType::BuiltInScalar(psl::parser_database::ScalarType::Geography)); + } + }, sql::ColumnTypeFamily::Udt(_) => { let entry = self.extension_type()?; return Some(ScalarFieldType::Extension(entry.id)); diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs index 4909704ab39c..c68a27156133 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs @@ -211,16 +211,21 @@ fn parse_typ_opt<'a>( } let parsed_typ = ScalarType::try_from_str(typ.inner(), false) - .map(|st| match st { - ScalarType::Int => ColumnType::Int32, - ScalarType::BigInt => ColumnType::Int64, - ScalarType::Float => ColumnType::Float, - ScalarType::Boolean => ColumnType::Boolean, - ScalarType::String => ColumnType::Text, - ScalarType::DateTime => ColumnType::DateTime, - ScalarType::Json => ColumnType::Json, - ScalarType::Bytes => ColumnType::Bytes, - ScalarType::Decimal => ColumnType::Numeric, + .and_then(|st| match st { + ScalarType::Int => Some(ColumnType::Int32), + ScalarType::BigInt => Some(ColumnType::Int64), + ScalarType::Float => Some(ColumnType::Float), + ScalarType::Boolean => Some(ColumnType::Boolean), + ScalarType::String => Some(ColumnType::Text), + ScalarType::DateTime => Some(ColumnType::DateTime), + ScalarType::Json => Some(ColumnType::Json), + ScalarType::Bytes => Some(ColumnType::Bytes), + ScalarType::Decimal => Some(ColumnType::Numeric), + // PostGIS spatial scalars aren't bindable as raw-SQL parameter placeholders + // through this doc-comment syntax (callers should pass WKB via `{Bytes}` / + // `{String}` instead). Surface as "unknown" so the caller-side error message + // below applies uniformly. + ScalarType::Geometry | ScalarType::Geography => None, }) .map(ParsedParamType::ColumnType) .or_else(|| { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs index ee830d464240..f418f09d486d 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator.rs @@ -441,24 +441,36 @@ fn push_column_for_scalar_field(field: ScalarFieldWalker<'_>, table_id: sql::Tab ScalarFieldType::CompositeType(_) => { push_column_for_builtin_scalar_type(field, ScalarType::Json, table_id, ctx) } + ScalarFieldType::BuiltInScalar(ScalarType::Geometry | ScalarType::Geography) => { + push_column_for_geometry_field(field, table_id, ctx) + } ScalarFieldType::BuiltInScalar(scalar_type) => { push_column_for_builtin_scalar_type(field, scalar_type, table_id, ctx) } - ScalarFieldType::Geometry(spec) => push_column_for_geometry_field(field, spec, table_id, ctx), ScalarFieldType::Unsupported(_) => push_column_for_model_unsupported_scalar_field(field, table_id, ctx), } } -fn push_column_for_geometry_field( - field: ScalarFieldWalker<'_>, - spec: GeometrySpec, - table_id: sql::TableId, - ctx: &mut Context<'_>, -) { +fn push_column_for_geometry_field(field: ScalarFieldWalker<'_>, table_id: sql::TableId, ctx: &mut Context<'_>) { let connector = ctx.flavour.datamodel_connector(); + // Resolve `GeometrySpec` (subtype/SRID/spatial kind) from the native attribute first, then + // fall back to the connector default (unconstrained `geometry` / `geography`). The field type + // itself no longer carries a spec — it only encodes the spatial kind via the `ScalarType` + // variant, just like `String @db.VarChar(300)` carries the length via the native attribute. let native_type = field .native_type_instance(connector) - .or_else(|| connector.default_native_type_for_scalar_type(&ScalarFieldType::Geometry(spec), ctx.datamodel)); + .or_else(|| connector.default_native_type_for_scalar_type(&field.scalar_field_type(), ctx.datamodel)); + let spec = native_type + .as_ref() + .and_then(|nt| connector.geometry_spec_for_native_type(nt)) + .unwrap_or(GeometrySpec { + subtype: db::GeometrySubtype::Geometry, + srid: None, + spatial: field + .scalar_field_type() + .postgis_spatial_kind() + .expect("push_column_for_geometry_field invoked on non-spatial scalar"), + }); let default = field.default_value().map(|def| { sql::DefaultValue::db_generated::(unwrap_dbgenerated(def.value())) @@ -627,6 +639,13 @@ fn push_column_for_builtin_scalar_type( ScalarType::Bytes => sql::ColumnTypeFamily::Binary, ScalarType::Decimal => sql::ColumnTypeFamily::Decimal, ScalarType::BigInt => sql::ColumnTypeFamily::BigInt, + // PostGIS scalars are dispatched by `push_column_for_scalar_field` to their own + // helper (`push_column_for_geometry_field`); reaching here would be a routing bug. + ScalarType::Geometry | ScalarType::Geography => { + unreachable!( + "PostGIS scalar types must be dispatched through push_column_for_geometry_field" + ) + } }; let native_type = field.native_type_instance(connector).or_else(|| { diff --git a/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs index 7e61cc9e756d..546cbf25fda1 100644 --- a/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs +++ b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs @@ -21,9 +21,58 @@ async fn introspect_geometry_columns(api: &mut TestApi) -> TestResult { assert!(schema.contains("extensions = [postgis")); assert!(schema.contains("model locations")); - assert!(schema.contains("position Geometry(Point, 4326)?")); - assert!(schema.contains("path") && schema.contains("Geometry(LineString)")); - assert!(schema.contains("area Geometry(Polygon, 3857)")); + assert!(schema.contains("position Geometry?") && schema.contains("@db.Geometry(Point, 4326)")); + assert!(schema.contains("path") && schema.contains("@db.Geometry(LineString)")); + assert!(schema.contains("area Geometry") && schema.contains("@db.Geometry(Polygon, 3857)")); + + Ok(()) +} + +// SevInf #1/#4: PostGIS `geography` columns must surface as `Geography @db.Geography(...)` in the +// re-introspected schema. Until this audit, the introspection layer collapsed both spatial kinds +// into the `Geometry` PSL keyword, which then collided with the `@db.Geography` native attribute +// validator ("Native type Geography is not compatible with declared field type Geometry"). This +// regression test pins the planar/geodetic split end-to-end against a live PostGIS catalog. +#[test_connector(tags(Postgres), exclude(CockroachDb), preview_features("postgresqlExtensions"))] +async fn introspect_geography_columns(api: &mut TestApi) -> TestResult { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis").await; + api.raw_cmd(indoc! {r#" + CREATE TABLE places ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + location geography(Point, 4326), + region geography(Polygon, 4326) NOT NULL, + footprint geography + ); + "#}) + .await; + + let schema = api.introspect().await?; + + assert!(schema.contains("extensions = [postgis")); + assert!(schema.contains("model places")); + + // Field type MUST be `Geography` (not `Geometry`) so the native attribute pairing validates. + assert!( + schema.contains("location Geography?") && schema.contains("@db.Geography(Point, 4326)"), + "expected `location Geography? ... @db.Geography(Point, 4326)`, got:\n{schema}", + ); + assert!( + schema.contains("region Geography") && schema.contains("@db.Geography(Polygon, 4326)"), + "expected `region Geography ... @db.Geography(Polygon, 4326)`, got:\n{schema}", + ); + // Untyped `geography` (no typmod) keeps the bare native form without subtype/SRID args. + assert!( + schema.contains("footprint Geography?"), + "expected `footprint Geography?`, got:\n{schema}", + ); + + // The introspector must NEVER emit `Geometry @db.Geography(...)` — that combo is rejected by + // PSL validation and used to be the silent failure mode for `geography` columns. + assert!( + !schema.contains("Geometry @db.Geography") && !schema.contains("Geometry? @db.Geography"), + "introspected schema must not pair `Geometry` keyword with `@db.Geography`, got:\n{schema}", + ); Ok(()) } diff --git a/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs index 022e91ebab09..ff3e26485597 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs @@ -9,7 +9,7 @@ fn create_table_with_geometry(api: TestApi) { let dm = indoc! {r#" model Location { id Int @id @default(autoincrement()) - position Geometry(Point, 4326)? + position Geometry? @db.Geometry(Point, 4326) } "#}; @@ -32,7 +32,7 @@ fn alter_geometry_srid(api: TestApi) { let schema1 = indoc! {r#" model Location { id Int @id - position Geometry(Point, 4326) + position Geometry @db.Geometry(Point, 4326) } "#}; @@ -41,7 +41,7 @@ fn alter_geometry_srid(api: TestApi) { let schema2 = indoc! {r#" model Location { id Int @id - position Geometry(Point, 3857) + position Geometry @db.Geometry(Point, 3857) } "#}; @@ -62,8 +62,8 @@ fn geometry_round_trip(mut api: TestApi) { let dm = indoc! {r#" model Location { id Int @id - position Geometry(Point, 4326)? - path Geometry(LineString, 4326)? + position Geometry? @db.Geometry(Point, 4326) + path Geometry? @db.Geometry(LineString, 4326) } "#}; @@ -83,6 +83,78 @@ fn geometry_round_trip(mut api: TestApi) { .unwrap() .into_single_datamodel(); - assert!(introspected.contains("Geometry(Point, 4326)")); - assert!(introspected.contains("Geometry(LineString, 4326)")); + assert!(introspected.contains("@db.Geometry(Point, 4326)")); + assert!(introspected.contains("@db.Geometry(LineString, 4326)")); +} + +// SevInf #1/#4: `Geography` is a first-class PSL keyword paired with `@db.Geography(...)`. +// PostGIS persists the geodetic kind in `pg_attribute.atttypid`, so a pushed `Geography` column +// must come back as `geography(...)` and the renderer must round-trip it intact. +#[test_connector(tags(Postgres), exclude(CockroachDb))] +fn create_table_with_geography(api: TestApi) { + let dm = indoc! {r#" + model Place { + id Int @id @default(autoincrement()) + region Geography? @db.Geography(Polygon, 4326) + } + "#}; + + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis"); + + api.schema_push_w_datasource(dm).send().assert_green(); + + let connector = psl::builtin_connectors::POSTGRES; + api.assert_schema().assert_table("Place", |table| { + table.assert_column("region", |col| { + col.assert_native_type("geography(Polygon,4326)", connector) + }) + }); +} + +// SevInf #1/#4 + introspection regression: Re-introspecting a `geography(...)` column must +// produce `Geography @db.Geography(...)`, not `Geometry @db.Geography(...)`. The latter is +// rejected at PSL validation time ("Native type Geography is not compatible with declared field +// type Geometry"), so a missing pairing here would silently break round-trip migrations. +#[test_connector(tags(Postgres), exclude(CockroachDb), preview_features("postgresqlExtensions"))] +fn geography_round_trip(mut api: TestApi) { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS postgis"); + + let dm = indoc! {r#" + model Place { + id Int @id + location Geography? @db.Geography(Point, 4326) + region Geography? @db.Geography(Polygon, 4326) + } + "#}; + + api.schema_push_w_datasource(dm).send().assert_green(); + + let schema = api.datamodel_with_provider(dm); + let previous_schema = psl::validate_without_extensions(schema.into()); + let mut ctx = IntrospectionContext::new( + previous_schema, + CompositeTypeDepth::Infinite, + None, + std::path::PathBuf::new(), + ); + ctx.render_config = false; + + let introspected = tok(api.connector.introspect(&ctx, &NoExtensionTypes)) + .unwrap() + .into_single_datamodel(); + + // The renderer must emit the `Geography` PSL keyword, NOT `Geometry`, otherwise PSL + // validation refuses the resulting schema (see the keyword-mismatch validation fixture). + assert!( + introspected.contains("Geography? @db.Geography(Point, 4326)"), + "expected `Geography? @db.Geography(Point, 4326)` in re-introspected schema, got:\n{introspected}", + ); + assert!( + introspected.contains("Geography? @db.Geography(Polygon, 4326)"), + "expected `Geography? @db.Geography(Polygon, 4326)` in re-introspected schema, got:\n{introspected}", + ); + assert!( + !introspected.contains("Geometry @db.Geography") && !introspected.contains("Geometry? @db.Geography"), + "re-introspected schema must NOT pair `Geometry` field type with `@db.Geography` native attribute, got:\n{introspected}", + ); } diff --git a/schema-engine/sql-schema-describer/src/postgres.rs b/schema-engine/sql-schema-describer/src/postgres.rs index ebcd7c7dfa7e..08b1cea6c335 100644 --- a/schema-engine/sql-schema-describer/src/postgres.rs +++ b/schema-engine/sql-schema-describer/src/postgres.rs @@ -13,7 +13,7 @@ use enumflags2::BitFlags; use indexmap::IndexMap; use indoc::indoc; use psl::{ - builtin_connectors::{CockroachType, KnownPostgresType, PostgresType}, + builtin_connectors::{CockroachType, GeometryNativeArgs, KnownPostgresType, PostgisNativeType, PostgresType}, datamodel_connector::NativeTypeInstance, parser_database::{GeometrySpec, GeometrySubtype, PostgisSpatialKind}, }; @@ -1534,7 +1534,8 @@ fn index_from_row( } fn map_geometry_subtype(pg_name: &str) -> GeometrySubtype { - match pg_name.trim().to_uppercase().as_str() { + let normalized = pg_name.trim().to_uppercase(); + match normalized.as_str() { "POINT" => GeometrySubtype::Point, "LINESTRING" => GeometrySubtype::LineString, "POLYGON" => GeometrySubtype::Polygon, @@ -1543,7 +1544,12 @@ fn map_geometry_subtype(pg_name: &str) -> GeometrySubtype { "MULTIPOLYGON" => GeometrySubtype::MultiPolygon, "GEOMETRYCOLLECTION" => GeometrySubtype::GeometryCollection, "GEOMETRY" => GeometrySubtype::Geometry, - _ => GeometrySubtype::Geometry, + other => { + tracing::warn!( + "Unknown PostGIS geometry subtype `{other}`; falling back to `GEOMETRY` (introspection will lose fidelity)." + ); + GeometrySubtype::Geometry + } } } @@ -1565,10 +1571,21 @@ fn parse_postgis_spatial(formatted_type: &str, spatial: PostgisSpatialKind) -> G if let Some(caps) = RE_TWO.captures(trimmed) { let subtype_str = caps.get(2).map(|m| m.as_str()).unwrap_or("GEOMETRY"); - let srid: i32 = caps.get(3).and_then(|m| m.as_str().parse().ok()).unwrap_or(0); + let srid_raw = caps.get(3).map(|m| m.as_str()); + let srid = match srid_raw.map(|s| s.parse::()) { + Some(Ok(v)) => Some(v), + Some(Err(err)) => { + tracing::warn!( + "Failed to parse PostGIS SRID `{}` for column type `{trimmed}`: {err}; treating as unspecified.", + srid_raw.unwrap_or_default(), + ); + None + } + None => None, + }; return GeometrySpec { subtype: map_geometry_subtype(subtype_str), - srid: Some(srid), + srid, spatial, }; } @@ -1582,6 +1599,9 @@ fn parse_postgis_spatial(formatted_type: &str, spatial: PostgisSpatialKind) -> G }; } + tracing::warn!( + "Unable to parse PostGIS spatial type `{trimmed}`; falling back to generic `GEOMETRY` with no SRID." + ); GeometrySpec { subtype: GeometrySubtype::Geometry, srid: None, @@ -1623,19 +1643,25 @@ fn get_column_type_family( if full_data_type == "geometry" && data_type == "USER-DEFINED" { let spec = parse_postgis_spatial(&row.get_expect_string("formatted_type"), PostgisSpatialKind::Geometry); - let sql = spec.postgres_sql_type(); + let args = GeometryNativeArgs { + subtype: spec.subtype, + srid: spec.srid, + }; return ( ColumnTypeFamily::Geometry(spec), - Some(PostgresType::Unknown(sql, Vec::new())), + Some(PostgresType::Postgis(PostgisNativeType::Geometry(args))), ); } if full_data_type == "geography" && data_type == "USER-DEFINED" { let spec = parse_postgis_spatial(&row.get_expect_string("formatted_type"), PostgisSpatialKind::Geography); - let sql = spec.postgres_sql_type(); + let args = GeometryNativeArgs { + subtype: spec.subtype, + srid: spec.srid, + }; return ( ColumnTypeFamily::Geometry(spec), - Some(PostgresType::Unknown(sql, Vec::new())), + Some(PostgresType::Postgis(PostgisNativeType::Geography(args))), ); } From 7e29da25afab6c4a7a3e7672a73fe17bd0fef81e Mon Sep 17 00:00:00 2001 From: Lam Hieu Date: Sat, 23 May 2026 15:50:07 +0700 Subject: [PATCH 6/6] test(postgis): match capitalized rendering in migration assertions and tolerate column-alignment padding in introspection checks --- .../tests/postgres/postgis_geometry.rs | 21 ++++++++++++------- .../migrations/postgres/postgis_geometry.rs | 6 +++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs index 546cbf25fda1..83212f1153ae 100644 --- a/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs +++ b/schema-engine/sql-introspection-tests/tests/postgres/postgis_geometry.rs @@ -21,9 +21,12 @@ async fn introspect_geometry_columns(api: &mut TestApi) -> TestResult { assert!(schema.contains("extensions = [postgis")); assert!(schema.contains("model locations")); - assert!(schema.contains("position Geometry?") && schema.contains("@db.Geometry(Point, 4326)")); - assert!(schema.contains("path") && schema.contains("@db.Geometry(LineString)")); - assert!(schema.contains("area Geometry") && schema.contains("@db.Geometry(Polygon, 3857)")); + // The introspector pretty-prints columns with alignment padding (`name Type @attr`), + // so we normalize whitespace before the substring check instead of pinning a single layout. + let normalized: String = schema.split_whitespace().collect::>().join(" "); + assert!(normalized.contains("position Geometry?") && normalized.contains("@db.Geometry(Point, 4326)")); + assert!(normalized.contains("path Geometry?") && normalized.contains("@db.Geometry(LineString)")); + assert!(normalized.contains("area Geometry") && normalized.contains("@db.Geometry(Polygon, 3857)")); Ok(()) } @@ -52,25 +55,29 @@ async fn introspect_geography_columns(api: &mut TestApi) -> TestResult { assert!(schema.contains("extensions = [postgis")); assert!(schema.contains("model places")); + // The renderer aligns column declarations with padding (`name Type @attr`); normalize + // whitespace so substring checks pin semantics, not formatting. + let normalized: String = schema.split_whitespace().collect::>().join(" "); + // Field type MUST be `Geography` (not `Geometry`) so the native attribute pairing validates. assert!( - schema.contains("location Geography?") && schema.contains("@db.Geography(Point, 4326)"), + normalized.contains("location Geography?") && normalized.contains("@db.Geography(Point, 4326)"), "expected `location Geography? ... @db.Geography(Point, 4326)`, got:\n{schema}", ); assert!( - schema.contains("region Geography") && schema.contains("@db.Geography(Polygon, 4326)"), + normalized.contains("region Geography") && normalized.contains("@db.Geography(Polygon, 4326)"), "expected `region Geography ... @db.Geography(Polygon, 4326)`, got:\n{schema}", ); // Untyped `geography` (no typmod) keeps the bare native form without subtype/SRID args. assert!( - schema.contains("footprint Geography?"), + normalized.contains("footprint Geography?"), "expected `footprint Geography?`, got:\n{schema}", ); // The introspector must NEVER emit `Geometry @db.Geography(...)` — that combo is rejected by // PSL validation and used to be the silent failure mode for `geography` columns. assert!( - !schema.contains("Geometry @db.Geography") && !schema.contains("Geometry? @db.Geography"), + !normalized.contains("Geometry @db.Geography") && !normalized.contains("Geometry? @db.Geography"), "introspected schema must not pair `Geometry` keyword with `@db.Geography`, got:\n{schema}", ); diff --git a/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs index ff3e26485597..39d4c2e9e4c1 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/postgres/postgis_geometry.rs @@ -20,7 +20,7 @@ fn create_table_with_geometry(api: TestApi) { let connector = psl::builtin_connectors::POSTGRES; api.assert_schema().assert_table("Location", |table| { table.assert_column("position", |col| { - col.assert_native_type("geometry(Point,4326)", connector) + col.assert_native_type("Geometry(Point,4326)", connector) }) }); } @@ -50,7 +50,7 @@ fn alter_geometry_srid(api: TestApi) { let connector = psl::builtin_connectors::POSTGRES; api.assert_schema().assert_table("Location", |table| { table.assert_column("position", |col| { - col.assert_native_type("geometry(Point,3857)", connector) + col.assert_native_type("Geometry(Point,3857)", connector) }) }); } @@ -106,7 +106,7 @@ fn create_table_with_geography(api: TestApi) { let connector = psl::builtin_connectors::POSTGRES; api.assert_schema().assert_table("Place", |table| { table.assert_column("region", |col| { - col.assert_native_type("geography(Polygon,4326)", connector) + col.assert_native_type("Geography(Polygon,4326)", connector) }) }); }