diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0ccd01ad72..bc8a20ce10 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -33,8 +33,7 @@ runs: shell: bash run: | apt-get update - apt-get install -y protobuf-compiler - apt-get install -y clang + apt-get install -y protobuf-compiler clang cmake - name: Install JDK ${{inputs.jdk-version}} uses: actions/setup-java@v4 diff --git a/.github/actions/setup-macos-builder/action.yaml b/.github/actions/setup-macos-builder/action.yaml index 7c1c8b522e..6d8ab02b69 100644 --- a/.github/actions/setup-macos-builder/action.yaml +++ b/.github/actions/setup-macos-builder/action.yaml @@ -49,7 +49,7 @@ runs: unzip $PROTO_ZIP echo "$HOME/d/protoc/bin" >> $GITHUB_PATH export PATH=$PATH:$HOME/d/protoc/bin - # install openssl and setup DYLD_LIBRARY_PATH + # install openssl brew install openssl OPENSSL_LIB_PATH=`brew --prefix openssl`/lib echo "openssl lib path is: ${OPENSSL_LIB_PATH}" diff --git a/docs/geo-functions.md b/docs/geo-functions.md new file mode 100644 index 0000000000..1060745110 --- /dev/null +++ b/docs/geo-functions.md @@ -0,0 +1,602 @@ + + +# Comet Geo Functions + +Comet provides 40 geospatial SQL functions registered as Spark SQL extensions. +All functions execute natively in the Rust/DataFusion engine when Comet is enabled +(`spark.comet.exec.enabled=true`). Geometries are represented as WKT strings. + +## Constructors + +Functions that create geometry values. + +### st_geomfromwkt + +```sql +st_geomfromwkt(wkt STRING) -> STRING +``` + +Parses a WKT string and returns the geometry. Returns `null` if the input is `null`. + +```sql +SELECT st_geomfromwkt('POINT(1.0 2.0)'); +-- POINT (1 2) +``` + +### st_geomfromgeojson + +```sql +st_geomfromgeojson(geojson STRING) -> STRING +``` + +Parses a GeoJSON string and returns the geometry as WKT. + +```sql +SELECT st_geomfromgeojson('{"type":"Point","coordinates":[1.0,2.0]}'); +-- POINT (1 2) +``` + +### st_point + +```sql +st_point(x DOUBLE, y DOUBLE) -> STRING +``` + +Creates a point geometry from x (longitude) and y (latitude) coordinates. + +```sql +SELECT st_point(1.0, 2.0); +-- POINT(1.0 2.0) +``` + +### st_makeenvelope + +```sql +st_makeenvelope(xmin DOUBLE, ymin DOUBLE, xmax DOUBLE, ymax DOUBLE) -> STRING +``` + +Creates a rectangular polygon (envelope/bounding box) from corner coordinates. +Returns `null` if any argument is `null`. + +```sql +SELECT st_makeenvelope(0.0, 0.0, 1.0, 1.0); +-- POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0)) +``` + +### st_makeline + +```sql +st_makeline(geom1 STRING, geom2 STRING) -> STRING +``` + +Creates a linestring connecting two geometries. + +```sql +SELECT st_makeline(st_point(0.0, 0.0), st_point(1.0, 1.0)); +-- LINESTRING (0 0, 1 1) +``` + +--- + +## Serializers + +Functions that convert a geometry to a text format. + +### st_astext + +```sql +st_astext(geom STRING) -> STRING +``` + +Returns the WKT representation of a geometry. + +```sql +SELECT st_astext(st_point(1.0, 2.0)); +-- POINT (1 2) +``` + +### st_asgeojson + +```sql +st_asgeojson(geom STRING) -> STRING +``` + +Returns the GeoJSON representation of a geometry. + +```sql +SELECT st_asgeojson(st_point(1.0, 2.0)); +-- {"type":"Point","coordinates":[1.0,2.0]} +``` + +--- + +## Predicates + +Functions that test a spatial relationship between two geometries and return a boolean. + +### st_contains + +```sql +st_contains(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if `geom1` contains `geom2` (i.e. no point of `geom2` is outside `geom1`). + +```sql +SELECT st_contains( + st_makeenvelope(0.0, 0.0, 2.0, 2.0), + st_point(1.0, 1.0) +); +-- true +``` + +### st_intersects + +```sql +st_intersects(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the two geometries share any point. + +```sql +SELECT st_intersects( + st_makeenvelope(0.0, 0.0, 2.0, 2.0), + st_makeenvelope(1.0, 1.0, 3.0, 3.0) +); +-- true +``` + +### st_within + +```sql +st_within(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if `geom1` is completely inside `geom2`. The inverse of `st_contains`. + +```sql +SELECT st_within( + st_point(1.0, 1.0), + st_makeenvelope(0.0, 0.0, 2.0, 2.0) +); +-- true +``` + +### st_covers + +```sql +st_covers(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if every point of `geom2` lies within or on the boundary of `geom1`. +Similar to `st_contains` but includes boundary points. + +### st_coveredby + +```sql +st_coveredby(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if every point of `geom1` lies within or on the boundary of `geom2`. +The inverse of `st_covers`. + +### st_equals + +```sql +st_equals(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the two geometries represent the same geometric shape +(point-set equality, not string equality). + +### st_touches + +```sql +st_touches(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the geometries share boundary points but their interiors do not intersect. + +### st_crosses + +```sql +st_crosses(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the geometries have some but not all interior points in common, +and the dimension of the intersection is less than the maximum dimension of either geometry. + +### st_disjoint + +```sql +st_disjoint(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the two geometries share no points. The inverse of `st_intersects`. + +### st_overlaps + +```sql +st_overlaps(geom1 STRING, geom2 STRING) -> BOOLEAN +``` + +Returns `true` if the two geometries of the same dimension intersect but neither contains the other. + +--- + +## Measurements + +Functions that compute a numeric value from one or two geometries. + +### st_area + +```sql +st_area(geom STRING) -> DOUBLE +``` + +Returns the area of a polygon geometry. Returns `0.0` for points and linestrings. + +```sql +SELECT st_area(st_makeenvelope(0.0, 0.0, 2.0, 3.0)); +-- 6.0 +``` + +### st_length + +```sql +st_length(geom STRING) -> DOUBLE +``` + +Returns the length of a linestring, or the perimeter of a polygon geometry. + +```sql +SELECT st_length(st_makeline(st_point(0.0, 0.0), st_point(3.0, 4.0))); +-- 5.0 +``` + +### st_perimeter + +```sql +st_perimeter(geom STRING) -> DOUBLE +``` + +Returns the perimeter of a polygon geometry. Returns `0.0` for points and linestrings. + +```sql +SELECT st_perimeter(st_makeenvelope(0.0, 0.0, 1.0, 1.0)); +-- 4.0 +``` + +### st_distance + +```sql +st_distance(geom1 STRING, geom2 STRING) -> DOUBLE +``` + +Returns the minimum planar (Cartesian) distance between two geometries. + +```sql +SELECT st_distance(st_point(0.0, 0.0), st_point(3.0, 4.0)); +-- 5.0 +``` + +### st_distancesphere + +```sql +st_distancesphere(geom1 STRING, geom2 STRING) -> DOUBLE +``` + +Returns the great-circle distance in metres between two point geometries, +assuming a spherical Earth model. + +```sql +SELECT st_distancesphere(st_point(-0.1276, 51.5074), st_point(2.3522, 48.8566)); +-- ~341550 (London to Paris, metres) +``` + +### st_hausdorffdistance + +```sql +st_hausdorffdistance(geom1 STRING, geom2 STRING) -> DOUBLE +``` + +Returns the Hausdorff distance between two geometries. Useful for measuring +how similar two shapes are. + +### st_numpoints + +```sql +st_numpoints(geom STRING) -> BIGINT +``` + +Returns the number of vertices in a geometry. + +```sql +SELECT st_numpoints(st_makeenvelope(0.0, 0.0, 1.0, 1.0)); +-- 5 +``` + +### st_x + +```sql +st_x(geom STRING) -> DOUBLE +``` + +Returns the x-coordinate (longitude) of a point geometry. + +```sql +SELECT st_x(st_point(1.5, 2.5)); +-- 1.5 +``` + +### st_y + +```sql +st_y(geom STRING) -> DOUBLE +``` + +Returns the y-coordinate (latitude) of a point geometry. + +```sql +SELECT st_y(st_point(1.5, 2.5)); +-- 2.5 +``` + +--- + +## Accessors + +Functions that return a property or derived geometry from a single geometry. + +### st_isempty + +```sql +st_isempty(geom STRING) -> BOOLEAN +``` + +Returns `true` if the geometry is empty (contains no points). + +### st_geometrytype + +```sql +st_geometrytype(geom STRING) -> STRING +``` + +Returns the type name of the geometry: `Point`, `LineString`, `Polygon`, +`MultiPoint`, `MultiLineString`, `MultiPolygon`, or `GeometryCollection`. + +```sql +SELECT st_geometrytype(st_point(1.0, 2.0)); +-- Point +``` + +--- + +## Transformations + +Functions that return a new geometry derived from the input. + +### st_centroid + +```sql +st_centroid(geom STRING) -> STRING +``` + +Returns the geometric centre (centroid) of a geometry as a point. + +```sql +SELECT st_centroid(st_makeenvelope(0.0, 0.0, 2.0, 2.0)); +-- POINT (1 1) +``` + +### st_envelope + +```sql +st_envelope(geom STRING) -> STRING +``` + +Returns the minimum bounding rectangle of a geometry as a polygon. + +```sql +SELECT st_envelope(st_makeline(st_point(1.0, 2.0), st_point(3.0, 4.0))); +-- POLYGON ((1 2, 3 2, 3 4, 1 4, 1 2)) +``` + +### st_convexhull + +```sql +st_convexhull(geom STRING) -> STRING +``` + +Returns the smallest convex polygon that contains all points of the geometry. + +### st_buffer + +```sql +st_buffer(geom STRING, distance DOUBLE) -> STRING +``` + +Returns a geometry that represents all points within `distance` of `geom`. +The result is a polygon for point and linestring inputs. + +```sql +SELECT st_buffer(st_point(0.0, 0.0), 1.0); +-- POLYGON ((1 0, ...)) (approximated circle) +``` + +### st_simplify + +```sql +st_simplify(geom STRING, tolerance DOUBLE) -> STRING +``` + +Simplifies a geometry using the Douglas-Peucker algorithm. Points within +`tolerance` of the simplified line are removed. May produce invalid topologies. + +```sql +SELECT st_simplify(st_makeenvelope(0.0, 0.0, 1.0, 1.0), 0.1); +``` + +### st_simplifypreservetopology + +```sql +st_simplifypreservetopology(geom STRING, tolerance DOUBLE) -> STRING +``` + +Same as `st_simplify` but guarantees the result is topologically valid +(no self-intersections, no collapse to empty). + +### st_flipcoordinates + +```sql +st_flipcoordinates(geom STRING) -> STRING +``` + +Swaps the x and y coordinates of every vertex. Useful for converting between +(longitude, latitude) and (latitude, longitude) conventions. + +```sql +SELECT st_flipcoordinates(st_point(1.0, 2.0)); +-- POINT (2 1) +``` + +### st_boundary + +```sql +st_boundary(geom STRING) -> STRING +``` + +Returns the boundary of a geometry. For a polygon this is its ring(s); +for a linestring it is its two endpoints; for a point it is empty. + +--- + +## Set Operations + +Functions that compute a new geometry from two input geometries. + +### st_union + +```sql +st_union(geom1 STRING, geom2 STRING) -> STRING +``` + +Returns a geometry representing all points in either `geom1` or `geom2`. + +```sql +SELECT st_union( + st_makeenvelope(0.0, 0.0, 1.0, 1.0), + st_makeenvelope(0.5, 0.5, 1.5, 1.5) +); +``` + +### st_intersection + +```sql +st_intersection(geom1 STRING, geom2 STRING) -> STRING +``` + +Returns a geometry representing the points shared by both `geom1` and `geom2`. + +```sql +SELECT st_intersection( + st_makeenvelope(0.0, 0.0, 2.0, 2.0), + st_makeenvelope(1.0, 1.0, 3.0, 3.0) +); +-- POLYGON ((1 1, 2 1, 2 2, 1 2, 1 1)) +``` + +### st_difference + +```sql +st_difference(geom1 STRING, geom2 STRING) -> STRING +``` + +Returns a geometry representing the points in `geom1` that are not in `geom2`. + +```sql +SELECT st_difference( + st_makeenvelope(0.0, 0.0, 2.0, 2.0), + st_makeenvelope(1.0, 1.0, 3.0, 3.0) +); +``` + +### st_symdifference + +```sql +st_symdifference(geom1 STRING, geom2 STRING) -> STRING +``` + +Returns a geometry representing the points in either `geom1` or `geom2` +but not both (the symmetric difference / XOR of the two shapes). + +```sql +SELECT st_symdifference( + st_makeenvelope(0.0, 0.0, 2.0, 2.0), + st_makeenvelope(1.0, 1.0, 3.0, 3.0) +); +``` + +--- + +## Function Summary + +| Function | Arguments | Return type | Category | +|---|---|---|---| +| `st_geomfromwkt` | wkt | STRING | Constructor | +| `st_geomfromgeojson` | geojson | STRING | Constructor | +| `st_point` | x, y | STRING | Constructor | +| `st_makeenvelope` | xmin, ymin, xmax, ymax | STRING | Constructor | +| `st_makeline` | geom1, geom2 | STRING | Constructor | +| `st_astext` | geom | STRING | Serializer | +| `st_asgeojson` | geom | STRING | Serializer | +| `st_contains` | geom1, geom2 | BOOLEAN | Predicate | +| `st_intersects` | geom1, geom2 | BOOLEAN | Predicate | +| `st_within` | geom1, geom2 | BOOLEAN | Predicate | +| `st_covers` | geom1, geom2 | BOOLEAN | Predicate | +| `st_coveredby` | geom1, geom2 | BOOLEAN | Predicate | +| `st_equals` | geom1, geom2 | BOOLEAN | Predicate | +| `st_touches` | geom1, geom2 | BOOLEAN | Predicate | +| `st_crosses` | geom1, geom2 | BOOLEAN | Predicate | +| `st_disjoint` | geom1, geom2 | BOOLEAN | Predicate | +| `st_overlaps` | geom1, geom2 | BOOLEAN | Predicate | +| `st_area` | geom | DOUBLE | Measurement | +| `st_length` | geom | DOUBLE | Measurement | +| `st_perimeter` | geom | DOUBLE | Measurement | +| `st_distance` | geom1, geom2 | DOUBLE | Measurement | +| `st_distancesphere` | geom1, geom2 | DOUBLE | Measurement | +| `st_hausdorffdistance` | geom1, geom2 | DOUBLE | Measurement | +| `st_numpoints` | geom | BIGINT | Measurement | +| `st_x` | geom | DOUBLE | Measurement | +| `st_y` | geom | DOUBLE | Measurement | +| `st_isempty` | geom | BOOLEAN | Accessor | +| `st_geometrytype` | geom | STRING | Accessor | +| `st_centroid` | geom | STRING | Transformation | +| `st_envelope` | geom | STRING | Transformation | +| `st_convexhull` | geom | STRING | Transformation | +| `st_buffer` | geom, distance | STRING | Transformation | +| `st_simplify` | geom, tolerance | STRING | Transformation | +| `st_simplifypreservetopology` | geom, tolerance | STRING | Transformation | +| `st_flipcoordinates` | geom | STRING | Transformation | +| `st_boundary` | geom | STRING | Transformation | +| `st_union` | geom1, geom2 | STRING | Set operation | +| `st_intersection` | geom1, geom2 | STRING | Set operation | +| `st_difference` | geom1, geom2 | STRING | Set operation | +| `st_symdifference` | geom1, geom2 | STRING | Set operation | diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index c58d446917..e6a024c5ea 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -77,6 +77,11 @@ iceberg = { workspace = true } iceberg-storage-opendal = { workspace = true } serde_json = "1.0" uuid = "1.23.0" +geo = "0.28" +geoarrow = "0.8" +geojson = { version = "0.24", features = ["geo-types"] } + +wkt = "0.11" [target.'cfg(target_os = "linux")'.dependencies] procfs = "0.18.0" diff --git a/native/core/src/execution/expressions/geo/mod.rs b/native/core/src/execution/expressions/geo/mod.rs new file mode 100644 index 0000000000..5b41b34a17 --- /dev/null +++ b/native/core/src/execution/expressions/geo/mod.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod st_area; +mod st_as_geojson; +mod st_as_text; +mod st_boundary; +mod st_buffer; +mod st_centroid; +mod st_contains; +mod st_convex_hull; +mod st_covered_by; +mod st_covers; +mod st_crosses; +mod st_difference; +mod st_disjoint; +mod st_distance; +mod st_distance_sphere; +mod st_envelope; +mod st_equals; +mod st_flip_coordinates; +mod st_geom_from_geojson; +mod st_geom_from_wkt; +mod st_geometry_type; +mod st_hausdorff_distance; +mod st_intersection; +mod st_intersects; +mod st_is_empty; +mod st_length; +mod st_make_envelope; +mod st_make_line; +mod st_num_points; +mod st_overlaps; +mod st_perimeter; +mod st_point; +mod st_simplify; +mod st_simplify_preserve_topology; +mod st_sym_difference; +mod st_touches; +mod st_union; +mod st_within; +mod st_x; +mod st_y; + +use datafusion::execution::context::SessionContext; +use datafusion::logical_expr::ScalarUDF; + +pub fn register_geo_functions(ctx: &SessionContext) { + // Constructors + ctx.register_udf(ScalarUDF::new_from_impl( + st_geom_from_wkt::StGeomFromWkt::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_geom_from_geojson::StGeomFromGeoJson::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_point::StPoint::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_make_envelope::StMakeEnvelope::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_make_line::StMakeLine::default())); + // Serializers + ctx.register_udf(ScalarUDF::new_from_impl(st_as_text::StAsText::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_as_geojson::StAsGeoJson::default(), + )); + // Predicates + ctx.register_udf(ScalarUDF::new_from_impl(st_contains::StContains::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_intersects::StIntersects::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_within::StWithin::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_covers::StCovers::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_covered_by::StCoveredBy::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_equals::StEquals::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_touches::StTouches::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_crosses::StCrosses::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_disjoint::StDisjoint::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_overlaps::StOverlaps::default())); + // Measurements + ctx.register_udf(ScalarUDF::new_from_impl(st_distance::StDistance::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_distance_sphere::StDistanceSphere::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_area::StArea::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_length::StLength::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_perimeter::StPerimeter::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_hausdorff_distance::StHausdorffDistance::default(), + )); + // Transformations + ctx.register_udf(ScalarUDF::new_from_impl(st_centroid::StCentroid::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_buffer::StBuffer::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_envelope::StEnvelope::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_convex_hull::StConvexHull::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_simplify::StSimplify::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_simplify_preserve_topology::StSimplifyPreserveTopology::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_flip_coordinates::StFlipCoordinates::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_boundary::StBoundary::default())); + // Set operations + ctx.register_udf(ScalarUDF::new_from_impl(st_union::StUnion::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_intersection::StIntersection::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_difference::StDifference::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_sym_difference::StSymDifference::default(), + )); + // Accessors + ctx.register_udf(ScalarUDF::new_from_impl(st_is_empty::StIsEmpty::default())); + ctx.register_udf(ScalarUDF::new_from_impl( + st_geometry_type::StGeometryType::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl( + st_num_points::StNumPoints::default(), + )); + ctx.register_udf(ScalarUDF::new_from_impl(st_x::StX::default())); + ctx.register_udf(ScalarUDF::new_from_impl(st_y::StY::default())); +} diff --git a/native/core/src/execution/expressions/geo/st_area.rs b/native/core/src/execution/expressions/geo/st_area.rs new file mode 100644 index 0000000000..fc34088f14 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_area.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::Area; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StArea { + signature: Signature, +} + +impl Default for StArea { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StArea { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_area" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + Some(geom.unsigned_area()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_as_geojson.rs b/native/core/src/execution/expressions/geo/st_as_geojson.rs new file mode 100644 index 0000000000..c15704acac --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_as_geojson.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geojson::Geometry as GeoJsonGeometry; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StAsGeoJson { + signature: Signature, +} + +impl Default for StAsGeoJson { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StAsGeoJson { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_asgeojson" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + let gj: GeoJsonGeometry = GeoJsonGeometry::from(&geom); + Some(gj.to_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_as_text.rs b/native/core/src/execution/expressions/geo/st_as_text.rs new file mode 100644 index 0000000000..539d775933 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_as_text.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StAsText { + signature: Signature, +} + +impl Default for StAsText { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StAsText { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_astext" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + Some(geom.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_boundary.rs b/native/core/src/execution/expressions/geo/st_boundary.rs new file mode 100644 index 0000000000..9534be62d6 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_boundary.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StBoundary { + signature: Signature, +} + +impl Default for StBoundary { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StBoundary { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_boundary" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + // Boundary of a polygon = its exterior ring as a LineString + // Boundary of a LineString = its two endpoints as a MultiPoint + let boundary: geo::Geometry = match geom { + geo::Geometry::Polygon(p) => geo::Geometry::LineString(p.exterior().clone()), + geo::Geometry::MultiPolygon(mp) => { + let rings: Vec> = + mp.iter().map(|p| p.exterior().clone()).collect(); + geo::Geometry::MultiLineString(geo::MultiLineString(rings)) + } + geo::Geometry::LineString(ls) => { + let coords = ls.0.clone(); + if coords.len() < 2 { + return Some("GEOMETRYCOLLECTION EMPTY".to_string()); + } + let pts = vec![ + geo::Point::from(*coords.first()?), + geo::Point::from(*coords.last()?), + ]; + geo::Geometry::MultiPoint(geo::MultiPoint(pts)) + } + _ => return Some("GEOMETRYCOLLECTION EMPTY".to_string()), + }; + Some(boundary.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_buffer.rs b/native/core/src/execution/expressions/geo/st_buffer.rs new file mode 100644 index 0000000000..18eed750aa --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_buffer.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::f64::consts::PI; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::scalar::ScalarValue; +use geo::{Coord, LineString, Point, Polygon}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StBuffer { + signature: Signature, +} + +impl Default for StBuffer { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StBuffer { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_buffer" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + // Extract distance — may be a scalar literal or a column. + let distance = scalar_to_f64(&args.args[1]); + let geom_arrays = ColumnarValue::values_to_arrays(std::slice::from_ref(&args.args[0]))?; + let geom_col = geom_arrays[0] + .as_any() + .downcast_ref::() + .unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + Some(geom_to_wkt(&buffer_geometry(&geom, distance, 32))) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn scalar_to_f64(val: &ColumnarValue) -> f64 { + match val { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => *v, + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(v), _p, s)) => { + (*v as f64) / 10f64.powi(*s as i32) + } + _ => 0.0, + } +} + +fn point_circle(cx: f64, cy: f64, radius: f64, segments: usize) -> Polygon { + let coords: Vec> = (0..=segments) + .map(|i| { + let angle = 2.0 * PI * (i as f64) / (segments as f64); + Coord { + x: cx + radius * angle.cos(), + y: cy + radius * angle.sin(), + } + }) + .collect(); + Polygon::new(LineString::from(coords), vec![]) +} + +fn coords_to_wkt(coords: &[Coord]) -> String { + let pts: Vec = coords.iter().map(|c| format!("{} {}", c.x, c.y)).collect(); + format!("({})", pts.join(",")) +} + +fn geom_to_wkt(geom: &geo::Geometry) -> String { + match geom { + geo::Geometry::Polygon(p) => { + format!("POLYGON({})", coords_to_wkt(p.exterior().0.as_slice())) + } + geo::Geometry::MultiPolygon(mp) => { + let parts: Vec = mp + .iter() + .map(|p| coords_to_wkt(p.exterior().0.as_slice())) + .collect(); + format!("MULTIPOLYGON(({}))", parts.join("),(")) + } + other => { + use wkt::ToWkt; + other.wkt_string() + } + } +} + +fn buffer_geometry( + geom: &geo::Geometry, + distance: f64, + segments: usize, +) -> geo::Geometry { + match geom { + geo::Geometry::Point(Point(c)) => { + geo::Geometry::Polygon(point_circle(c.x, c.y, distance, segments)) + } + geo::Geometry::MultiPoint(mp) => { + let polys: Vec> = mp + .iter() + .map(|Point(c)| point_circle(c.x, c.y, distance, segments)) + .collect(); + geo::Geometry::MultiPolygon(geo::MultiPolygon(polys)) + } + other => other.clone(), + } +} diff --git a/native/core/src/execution/expressions/geo/st_centroid.rs b/native/core/src/execution/expressions/geo/st_centroid.rs new file mode 100644 index 0000000000..20c6c39e6e --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_centroid.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::Centroid; +use wkt::ToWkt; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StCentroid { + signature: Signature, +} + +impl Default for StCentroid { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StCentroid { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_centroid" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt_str = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + let centroid: geo::Point = geom.centroid()?; + Some(centroid.to_wkt().to_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_contains.rs b/native/core/src/execution/expressions/geo/st_contains.rs new file mode 100644 index 0000000000..079c007b2c --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_contains.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StContains { + signature: Signature, +} + +impl Default for StContains { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StContains { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom1 = args[0].as_any().downcast_ref::().unwrap(); + let geom2 = args[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = geom1 + .iter() + .zip(geom2.iter()) + .map(|(g1, g2)| match (g1, g2) { + (Some(g1), Some(g2)) => { + let outer = geo::Geometry::::try_from_wkt_str(g1).ok()?; + let inner = geo::Geometry::::try_from_wkt_str(g2).ok()?; + // DE-9IM: T*****FF* — interior of inner intersects interior of outer, + // and inner has no part outside outer. Matches OGC/Sedona ST_Contains. + Some(outer.relate(&inner).is_contains()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_convex_hull.rs b/native/core/src/execution/expressions/geo/st_convex_hull.rs new file mode 100644 index 0000000000..05bc36b744 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_convex_hull.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::ConvexHull; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StConvexHull { + signature: Signature, +} + +impl Default for StConvexHull { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StConvexHull { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_convexhull" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + let hull = geom.convex_hull(); + Some(geo::Geometry::from(hull).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_covered_by.rs b/native/core/src/execution/expressions/geo/st_covered_by.rs new file mode 100644 index 0000000000..acc090d983 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_covered_by.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StCoveredBy { + signature: Signature, +} + +impl Default for StCoveredBy { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StCoveredBy { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_coveredby" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + Some(a.relate(&b).is_coveredby()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_covers.rs b/native/core/src/execution/expressions/geo/st_covers.rs new file mode 100644 index 0000000000..8b229017cd --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_covers.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StCovers { + signature: Signature, +} + +impl Default for StCovers { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StCovers { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_covers" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + // DE-9IM T*****FF* or *T****FF* or ***T**FF* or ****T*FF* + Some(a.relate(&b).is_covers()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_crosses.rs b/native/core/src/execution/expressions/geo/st_crosses.rs new file mode 100644 index 0000000000..4d864ae0cc --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_crosses.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StCrosses { + signature: Signature, +} + +impl Default for StCrosses { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StCrosses { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_crosses" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + Some(a.relate(&b).is_crosses()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_difference.rs b/native/core/src/execution/expressions/geo/st_difference.rs new file mode 100644 index 0000000000..1b8c85266a --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_difference.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::BooleanOps; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StDifference { + signature: Signature, +} + +impl Default for StDifference { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StDifference { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_difference" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: StringArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = as_multipolygon(w1?)?; + let b = as_multipolygon(w2?)?; + Some(a.difference(&b).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn as_multipolygon(wkt: &str) -> Option> { + match geo::Geometry::::try_from_wkt_str(wkt).ok()? { + geo::Geometry::Polygon(p) => Some(geo::MultiPolygon(vec![p])), + geo::Geometry::MultiPolygon(mp) => Some(mp), + _ => None, + } +} diff --git a/native/core/src/execution/expressions/geo/st_disjoint.rs b/native/core/src/execution/expressions/geo/st_disjoint.rs new file mode 100644 index 0000000000..cef9f5d084 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_disjoint.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StDisjoint { + signature: Signature, +} + +impl Default for StDisjoint { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StDisjoint { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_disjoint" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + Some(a.relate(&b).is_disjoint()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_distance.rs b/native/core/src/execution/expressions/geo/st_distance.rs new file mode 100644 index 0000000000..2510e78da1 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_distance.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use geo::EuclideanDistance; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StDistance { + signature: Signature, +} + +impl Default for StDistance { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StDistance { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_distance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom1 = args[0].as_any().downcast_ref::().unwrap(); + let geom2 = args[1].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = geom1 + .iter() + .zip(geom2.iter()) + .map(|(g1, g2)| match (g1, g2) { + (Some(g1), Some(g2)) => { + let a = geo::Geometry::::try_from_wkt_str(g1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(g2).ok()?; + // If geometries intersect, distance is 0.0 by definition. + // Otherwise use EuclideanDistance which correctly measures + // nearest-point distance between disjoint geometries. + if a.relate(&b).is_intersects() { + Some(0.0) + } else { + Some(a.euclidean_distance(&b)) + } + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_distance_sphere.rs b/native/core/src/execution/expressions/geo/st_distance_sphere.rs new file mode 100644 index 0000000000..85941c7594 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_distance_sphere.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::f64::consts::PI; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +const EARTH_RADIUS_METERS: f64 = 6_371_008.8; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StDistanceSphere { + signature: Signature, +} + +impl Default for StDistanceSphere { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StDistanceSphere { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_distancesphere" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = geo::Geometry::::try_from_wkt_str(w1?).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2?).ok()?; + let (lon1, lat1) = centroid_coords(&a)?; + let (lon2, lat2) = centroid_coords(&b)?; + Some(haversine(lon1, lat1, lon2, lat2)) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn centroid_coords(geom: &geo::Geometry) -> Option<(f64, f64)> { + use geo::Centroid; + let c = geom.centroid()?; + Some((c.x(), c.y())) +} + +fn haversine(lon1: f64, lat1: f64, lon2: f64, lat2: f64) -> f64 { + let to_rad = PI / 180.0; + let dlat = (lat2 - lat1) * to_rad; + let dlon = (lon2 - lon1) * to_rad; + let a = (dlat / 2.0).sin().powi(2) + + lat1.to_radians().cos() * lat2.to_radians().cos() * (dlon / 2.0).sin().powi(2); + 2.0 * EARTH_RADIUS_METERS * a.sqrt().asin() +} diff --git a/native/core/src/execution/expressions/geo/st_envelope.rs b/native/core/src/execution/expressions/geo/st_envelope.rs new file mode 100644 index 0000000000..897245a5c1 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_envelope.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::BoundingRect; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StEnvelope { + signature: Signature, +} + +impl Default for StEnvelope { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StEnvelope { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_envelope" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + let rect = geom.bounding_rect()?; + Some(geo::Geometry::from(rect).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_equals.rs b/native/core/src/execution/expressions/geo/st_equals.rs new file mode 100644 index 0000000000..af8d768644 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_equals.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StEquals { + signature: Signature, +} + +impl Default for StEquals { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StEquals { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_equals" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + // DE-9IM T*F**FFF* + Some(a.relate(&b).is_equal_topo()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_flip_coordinates.rs b/native/core/src/execution/expressions/geo/st_flip_coordinates.rs new file mode 100644 index 0000000000..63435ea835 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_flip_coordinates.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::MapCoords; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StFlipCoordinates { + signature: Signature, +} + +impl Default for StFlipCoordinates { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StFlipCoordinates { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_flipcoordinates" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + let flipped = geom.map_coords(|geo::Coord { x, y }| geo::Coord { x: y, y: x }); + Some(flipped.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_geom_from_geojson.rs b/native/core/src/execution/expressions/geo/st_geom_from_geojson.rs new file mode 100644 index 0000000000..adc0cb7c6e --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_geom_from_geojson.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::ToWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StGeomFromGeoJson { + signature: Signature, +} + +impl Default for StGeomFromGeoJson { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StGeomFromGeoJson { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_geomfromgeojson" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let json_str = v?; + let gj: geojson::Geometry = json_str.parse().ok()?; + let geom: geo::Geometry = geo::Geometry::try_from(&gj).ok()?; + Some(geom.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_geom_from_wkt.rs b/native/core/src/execution/expressions/geo/st_geom_from_wkt.rs new file mode 100644 index 0000000000..981db21e9d --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_geom_from_wkt.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StGeomFromWkt { + signature: Signature, +} + +impl Default for StGeomFromWkt { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StGeomFromWkt { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_geomfromwkt" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + // Validate by parsing then re-serialising to normalised WKT + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + use wkt::ToWkt; + Some(geom.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_geometry_type.rs b/native/core/src/execution/expressions/geo/st_geometry_type.rs new file mode 100644 index 0000000000..b512cf1542 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_geometry_type.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StGeometryType { + signature: Signature, +} + +impl Default for StGeometryType { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StGeometryType { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_geometrytype" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + let type_name = match geom { + geo::Geometry::Point(_) => "ST_Point", + geo::Geometry::Line(_) => "ST_LineString", + geo::Geometry::LineString(_) => "ST_LineString", + geo::Geometry::Polygon(_) => "ST_Polygon", + geo::Geometry::MultiPoint(_) => "ST_MultiPoint", + geo::Geometry::MultiLineString(_) => "ST_MultiLineString", + geo::Geometry::MultiPolygon(_) => "ST_MultiPolygon", + geo::Geometry::GeometryCollection(_) => "ST_GeometryCollection", + geo::Geometry::Rect(_) => "ST_Polygon", + geo::Geometry::Triangle(_) => "ST_Polygon", + }; + Some(type_name.to_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_hausdorff_distance.rs b/native/core/src/execution/expressions/geo/st_hausdorff_distance.rs new file mode 100644 index 0000000000..7357c61f90 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_hausdorff_distance.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::HausdorffDistance; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StHausdorffDistance { + signature: Signature, +} + +impl Default for StHausdorffDistance { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StHausdorffDistance { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_hausdorffdistance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = geo::Geometry::::try_from_wkt_str(w1?).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2?).ok()?; + Some(a.hausdorff_distance(&b)) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_intersection.rs b/native/core/src/execution/expressions/geo/st_intersection.rs new file mode 100644 index 0000000000..45165f0c42 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_intersection.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::BooleanOps; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StIntersection { + signature: Signature, +} + +impl Default for StIntersection { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StIntersection { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_intersection" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: StringArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = as_multipolygon(w1?)?; + let b = as_multipolygon(w2?)?; + Some(a.intersection(&b).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn as_multipolygon(wkt: &str) -> Option> { + match geo::Geometry::::try_from_wkt_str(wkt).ok()? { + geo::Geometry::Polygon(p) => Some(geo::MultiPolygon(vec![p])), + geo::Geometry::MultiPolygon(mp) => Some(mp), + _ => None, + } +} diff --git a/native/core/src/execution/expressions/geo/st_intersects.rs b/native/core/src/execution/expressions/geo/st_intersects.rs new file mode 100644 index 0000000000..4ba2b98d6a --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_intersects.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StIntersects { + signature: Signature, +} + +impl Default for StIntersects { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StIntersects { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_intersects" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom1 = args[0].as_any().downcast_ref::().unwrap(); + let geom2 = args[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = geom1 + .iter() + .zip(geom2.iter()) + .map(|(g1, g2)| match (g1, g2) { + (Some(g1), Some(g2)) => { + let a = geo::Geometry::::try_from_wkt_str(g1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(g2).ok()?; + // DE-9IM: NOT Disjoint — any part of a shares space with any part of b. + Some(!a.relate(&b).is_disjoint()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_is_empty.rs b/native/core/src/execution/expressions/geo/st_is_empty.rs new file mode 100644 index 0000000000..30f804e301 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_is_empty.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::{CoordsIter, HasDimensions}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StIsEmpty { + signature: Signature, +} + +impl Default for StIsEmpty { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StIsEmpty { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_isempty" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + Some(match geom { + geo::Geometry::GeometryCollection(gc) => gc.is_empty(), + geo::Geometry::MultiPoint(mp) => mp.is_empty(), + geo::Geometry::MultiLineString(ml) => ml.is_empty(), + geo::Geometry::MultiPolygon(mp) => mp.is_empty(), + geo::Geometry::LineString(ls) => ls.coords_count() == 0, + _ => false, + }) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_length.rs b/native/core/src/execution/expressions/geo/st_length.rs new file mode 100644 index 0000000000..346266e960 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_length.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::EuclideanLength; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StLength { + signature: Signature, +} + +impl Default for StLength { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StLength { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + let len = match geom { + geo::Geometry::LineString(ls) => ls.euclidean_length(), + geo::Geometry::MultiLineString(ml) => ml.euclidean_length(), + geo::Geometry::Polygon(p) => p.exterior().euclidean_length(), + geo::Geometry::MultiPolygon(mp) => { + mp.iter().map(|p| p.exterior().euclidean_length()).sum() + } + _ => 0.0, + }; + Some(len) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_make_envelope.rs b/native/core/src/execution/expressions/geo/st_make_envelope.rs new file mode 100644 index 0000000000..666d60cd97 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_make_envelope.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Decimal128Array, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StMakeEnvelope { + signature: Signature, +} + +impl Default for StMakeEnvelope { + fn default() -> Self { + Self { + signature: Signature::any(4, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StMakeEnvelope { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_makeenvelope" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let xmins = extract_f64_col(&arrays[0]); + let ymins = extract_f64_col(&arrays[1]); + let xmaxs = extract_f64_col(&arrays[2]); + let ymaxs = extract_f64_col(&arrays[3]); + + let result: StringArray = (0..xmins.len()) + .map(|i| { + let (xmin, ymin, xmax, ymax) = (xmins[i]?, ymins[i]?, xmaxs[i]?, ymaxs[i]?); + Some(format!( + "POLYGON(({xmin} {ymin},{xmax} {ymin},{xmax} {ymax},{xmin} {ymax},{xmin} {ymin}))" + )) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn extract_f64_col(arr: &dyn Array) -> Vec> { + if let Some(a) = arr.as_any().downcast_ref::() { + return a.iter().collect(); + } + if let Some(a) = arr.as_any().downcast_ref::() { + let scale = match arr.data_type() { + DataType::Decimal128(_, s) => *s as i32, + _ => 0, + }; + return a + .iter() + .map(|v| v.map(|n| (n as f64) / 10f64.powi(scale))) + .collect(); + } + vec![None; arr.len()] +} diff --git a/native/core/src/execution/expressions/geo/st_make_line.rs b/native/core/src/execution/expressions/geo/st_make_line.rs new file mode 100644 index 0000000000..70c61fd923 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_make_line.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StMakeLine { + signature: Signature, +} + +impl Default for StMakeLine { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StMakeLine { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_makeline" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let p1s = arrays[0].as_any().downcast_ref::().unwrap(); + let p2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: StringArray = p1s + .iter() + .zip(p2s.iter()) + .map(|(w1, w2)| { + let g1 = geo::Geometry::::try_from_wkt_str(w1?).ok()?; + let g2 = geo::Geometry::::try_from_wkt_str(w2?).ok()?; + let c1 = match g1 { + geo::Geometry::Point(p) => p.0, + _ => return None, + }; + let c2 = match g2 { + geo::Geometry::Point(p) => p.0, + _ => return None, + }; + Some(format!("LINESTRING({} {},{} {})", c1.x, c1.y, c2.x, c2.y)) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_num_points.rs b/native/core/src/execution/expressions/geo/st_num_points.rs new file mode 100644 index 0000000000..d4a8920e93 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_num_points.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::CoordsIter; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StNumPoints { + signature: Signature, +} + +impl Default for StNumPoints { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StNumPoints { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_numpoints" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: Int64Array = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + Some(geom.coords_count() as i64) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_overlaps.rs b/native/core/src/execution/expressions/geo/st_overlaps.rs new file mode 100644 index 0000000000..73399d08d0 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_overlaps.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StOverlaps { + signature: Signature, +} + +impl Default for StOverlaps { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StOverlaps { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_overlaps" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + Some(a.relate(&b).is_overlaps()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_perimeter.rs b/native/core/src/execution/expressions/geo/st_perimeter.rs new file mode 100644 index 0000000000..63965f3687 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_perimeter.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::EuclideanLength; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StPerimeter { + signature: Signature, +} + +impl Default for StPerimeter { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StPerimeter { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_perimeter" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let col = arrays[0].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + let len = match geom { + geo::Geometry::Polygon(p) => { + p.exterior().euclidean_length() + + p.interiors() + .iter() + .map(|r| r.euclidean_length()) + .sum::() + } + geo::Geometry::MultiPolygon(mp) => mp + .iter() + .map(|p| { + p.exterior().euclidean_length() + + p.interiors() + .iter() + .map(|r| r.euclidean_length()) + .sum::() + }) + .sum(), + _ => 0.0, + }; + Some(len) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_point.rs b/native/core/src/execution/expressions/geo/st_point.rs new file mode 100644 index 0000000000..e1715b2082 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_point.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Decimal128Array, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StPoint { + signature: Signature, +} + +impl Default for StPoint { + fn default() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StPoint { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_point" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let xs = extract_f64_col(&arrays[0]); + let ys = extract_f64_col(&arrays[1]); + + let result: StringArray = xs + .iter() + .zip(ys.iter()) + .map(|(x, y)| { + let (x, y) = ((*x)?, (*y)?); + Some(format!("POINT({} {})", x, y)) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn extract_f64_col(arr: &dyn Array) -> Vec> { + if let Some(a) = arr.as_any().downcast_ref::() { + return a.iter().collect(); + } + if let Some(a) = arr.as_any().downcast_ref::() { + let scale = match arr.data_type() { + DataType::Decimal128(_, s) => *s as i32, + _ => 0, + }; + return a + .iter() + .map(|v| v.map(|n| (n as f64) / 10f64.powi(scale))) + .collect(); + } + vec![None; arr.len()] +} diff --git a/native/core/src/execution/expressions/geo/st_simplify.rs b/native/core/src/execution/expressions/geo/st_simplify.rs new file mode 100644 index 0000000000..06d68ba165 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_simplify.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::scalar::ScalarValue; +use geo::Simplify; +use wkt::{ToWkt, TryFromWkt}; + +fn scalar_to_f64(val: &ColumnarValue) -> f64 { + match val { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => *v, + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(v), _p, s)) => { + (*v as f64) / 10f64.powi(*s as i32) + } + _ => 0.0, + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StSimplify { + signature: Signature, +} + +impl Default for StSimplify { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StSimplify { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_simplify" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let tolerance = scalar_to_f64(&args.args[1]); + let geom_arrays = ColumnarValue::values_to_arrays(std::slice::from_ref(&args.args[0]))?; + let geom_col = geom_arrays[0] + .as_any() + .downcast_ref::() + .unwrap(); + + let result: StringArray = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + let simplified = match geom { + geo::Geometry::LineString(ls) => { + geo::Geometry::LineString(ls.simplify(&tolerance)) + } + geo::Geometry::MultiLineString(ml) => { + geo::Geometry::MultiLineString(ml.simplify(&tolerance)) + } + geo::Geometry::Polygon(p) => geo::Geometry::Polygon(p.simplify(&tolerance)), + geo::Geometry::MultiPolygon(mp) => { + geo::Geometry::MultiPolygon(mp.simplify(&tolerance)) + } + other => other, + }; + Some(simplified.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_simplify_preserve_topology.rs b/native/core/src/execution/expressions/geo/st_simplify_preserve_topology.rs new file mode 100644 index 0000000000..549b750e2e --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_simplify_preserve_topology.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::scalar::ScalarValue; +use geo::SimplifyVwPreserve; +use wkt::{ToWkt, TryFromWkt}; + +fn scalar_to_f64(val: &ColumnarValue) -> f64 { + match val { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => *v, + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => *v as f64, + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(v), _p, s)) => { + (*v as f64) / 10f64.powi(*s as i32) + } + _ => 0.0, + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StSimplifyPreserveTopology { + signature: Signature, +} + +impl Default for StSimplifyPreserveTopology { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StSimplifyPreserveTopology { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_simplifypreservetopology" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let tolerance = scalar_to_f64(&args.args[1]); + let geom_arrays = ColumnarValue::values_to_arrays(std::slice::from_ref(&args.args[0]))?; + let col = geom_arrays[0] + .as_any() + .downcast_ref::() + .unwrap(); + + let result: StringArray = col + .iter() + .map(|v| { + let wkt_str = v?; + let geom = geo::Geometry::::try_from_wkt_str(wkt_str).ok()?; + let simplified = match geom { + geo::Geometry::LineString(ls) => { + geo::Geometry::LineString(ls.simplify_vw_preserve(&tolerance)) + } + geo::Geometry::MultiLineString(ml) => { + geo::Geometry::MultiLineString(ml.simplify_vw_preserve(&tolerance)) + } + geo::Geometry::Polygon(p) => { + geo::Geometry::Polygon(p.simplify_vw_preserve(&tolerance)) + } + geo::Geometry::MultiPolygon(mp) => { + geo::Geometry::MultiPolygon(mp.simplify_vw_preserve(&tolerance)) + } + other => other, + }; + Some(simplified.wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_sym_difference.rs b/native/core/src/execution/expressions/geo/st_sym_difference.rs new file mode 100644 index 0000000000..52182ae784 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_sym_difference.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::BooleanOps; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StSymDifference { + signature: Signature, +} + +impl Default for StSymDifference { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StSymDifference { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_symdifference" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: StringArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = as_multipolygon(w1?)?; + let b = as_multipolygon(w2?)?; + Some(a.xor(&b).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn as_multipolygon(wkt: &str) -> Option> { + match geo::Geometry::::try_from_wkt_str(wkt).ok()? { + geo::Geometry::Polygon(p) => Some(geo::MultiPolygon(vec![p])), + geo::Geometry::MultiPolygon(mp) => Some(mp), + _ => None, + } +} diff --git a/native/core/src/execution/expressions/geo/st_touches.rs b/native/core/src/execution/expressions/geo/st_touches.rs new file mode 100644 index 0000000000..6e1b4cedb9 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_touches.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StTouches { + signature: Signature, +} + +impl Default for StTouches { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StTouches { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_touches" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| match (w1, w2) { + (Some(w1), Some(w2)) => { + let a = geo::Geometry::::try_from_wkt_str(w1).ok()?; + let b = geo::Geometry::::try_from_wkt_str(w2).ok()?; + Some(a.relate(&b).is_touches()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_union.rs b/native/core/src/execution/expressions/geo/st_union.rs new file mode 100644 index 0000000000..90a4c65543 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_union.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::BooleanOps; +use wkt::{ToWkt, TryFromWkt}; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StUnion { + signature: Signature, +} + +impl Default for StUnion { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StUnion { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_union" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + let g1s = arrays[0].as_any().downcast_ref::().unwrap(); + let g2s = arrays[1].as_any().downcast_ref::().unwrap(); + + let result: StringArray = g1s + .iter() + .zip(g2s.iter()) + .map(|(w1, w2)| { + let a = as_multipolygon(w1?)?; + let b = as_multipolygon(w2?)?; + Some(a.union(&b).wkt_string()) + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} + +fn as_multipolygon(wkt: &str) -> Option> { + match geo::Geometry::::try_from_wkt_str(wkt).ok()? { + geo::Geometry::Polygon(p) => Some(geo::MultiPolygon(vec![p])), + geo::Geometry::MultiPolygon(mp) => Some(mp), + _ => None, + } +} diff --git a/native/core/src/execution/expressions/geo/st_within.rs b/native/core/src/execution/expressions/geo/st_within.rs new file mode 100644 index 0000000000..0f328a01c1 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_within.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use geo::relate::Relate; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StWithin { + signature: Signature, +} + +impl Default for StWithin { + fn default() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StWithin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_within" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom1 = args[0].as_any().downcast_ref::().unwrap(); + let geom2 = args[1].as_any().downcast_ref::().unwrap(); + + let result: BooleanArray = geom1 + .iter() + .zip(geom2.iter()) + .map(|(g1, g2)| match (g1, g2) { + (Some(g1), Some(g2)) => { + let inner = geo::Geometry::::try_from_wkt_str(g1).ok()?; + let outer = geo::Geometry::::try_from_wkt_str(g2).ok()?; + // DE-9IM: TF*FF**** — inner's interior intersects outer's interior, + // and inner has no part outside outer. Matches OGC/Sedona ST_Within. + Some(inner.relate(&outer).is_within()) + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_x.rs b/native/core/src/execution/expressions/geo/st_x.rs new file mode 100644 index 0000000000..baac9a7d17 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_x.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StX { + signature: Signature, +} + +impl Default for StX { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StX { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_x" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + match geom { + geo::Geometry::Point(p) => Some(p.x()), + _ => None, + } + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/geo/st_y.rs b/native/core/src/execution/expressions/geo/st_y.rs new file mode 100644 index 0000000000..712e1cc264 --- /dev/null +++ b/native/core/src/execution/expressions/geo/st_y.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion::common::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use wkt::TryFromWkt; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct StY { + signature: Signature, +} + +impl Default for StY { + fn default() -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for StY { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "st_y" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let geom_col = args[0].as_any().downcast_ref::().unwrap(); + + let result: Float64Array = geom_col + .iter() + .map(|g| { + let wkt = g?; + let geom = geo::Geometry::::try_from_wkt_str(wkt).ok()?; + match geom { + geo::Geometry::Point(p) => Some(p.y()), + _ => None, + } + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } +} diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index e174bd3747..a96e8ac780 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -20,6 +20,7 @@ pub mod arithmetic; pub mod bitwise; pub mod comparison; +pub mod geo; pub mod list_positions; pub mod logical; pub mod nullcheck; diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 97e3f851c5..f6fd5be1f7 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -562,6 +562,7 @@ fn prepare_datafusion_session_context( datafusion::functions_nested::register_all(&mut session_ctx)?; register_datafusion_spark_function(&session_ctx); + crate::execution::expressions::geo::register_geo_functions(&session_ctx); // Must be the last one to override existing functions with the same name datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 1ae90e1845..1d43549909 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf._ +import org.apache.comet.expressions.GeoExpressions import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions} import org.apache.comet.shims.ShimCometSparkSessionExtensions @@ -87,6 +88,46 @@ class CometSparkSessionExtensions with Logging with ShimCometSparkSessionExtensions { override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectFunction(GeoExpressions.stContainsInfo) + extensions.injectFunction(GeoExpressions.stIntersectsInfo) + extensions.injectFunction(GeoExpressions.stWithinInfo) + extensions.injectFunction(GeoExpressions.stDistanceInfo) + extensions.injectFunction(GeoExpressions.stAreaInfo) + extensions.injectFunction(GeoExpressions.stCentroidInfo) + extensions.injectFunction(GeoExpressions.stLengthInfo) + extensions.injectFunction(GeoExpressions.stIsEmptyInfo) + extensions.injectFunction(GeoExpressions.stGeometryTypeInfo) + extensions.injectFunction(GeoExpressions.stNumPointsInfo) + extensions.injectFunction(GeoExpressions.stXInfo) + extensions.injectFunction(GeoExpressions.stYInfo) + extensions.injectFunction(GeoExpressions.stEnvelopeInfo) + extensions.injectFunction(GeoExpressions.stConvexHullInfo) + extensions.injectFunction(GeoExpressions.stSimplifyInfo) + extensions.injectFunction(GeoExpressions.stBufferInfo) + extensions.injectFunction(GeoExpressions.stUnionInfo) + extensions.injectFunction(GeoExpressions.stIntersectionInfo) + extensions.injectFunction(GeoExpressions.stGeomFromWktInfo) + extensions.injectFunction(GeoExpressions.stGeomFromGeoJsonInfo) + extensions.injectFunction(GeoExpressions.stPointInfo) + extensions.injectFunction(GeoExpressions.stMakeEnvelopeInfo) + extensions.injectFunction(GeoExpressions.stMakeLineInfo) + extensions.injectFunction(GeoExpressions.stAsTextInfo) + extensions.injectFunction(GeoExpressions.stAsGeoJsonInfo) + extensions.injectFunction(GeoExpressions.stCoversInfo) + extensions.injectFunction(GeoExpressions.stCoveredByInfo) + extensions.injectFunction(GeoExpressions.stEqualsInfo) + extensions.injectFunction(GeoExpressions.stTouchesInfo) + extensions.injectFunction(GeoExpressions.stCrossesInfo) + extensions.injectFunction(GeoExpressions.stDisjointInfo) + extensions.injectFunction(GeoExpressions.stOverlapsInfo) + extensions.injectFunction(GeoExpressions.stDistanceSphereInfo) + extensions.injectFunction(GeoExpressions.stPerimeterInfo) + extensions.injectFunction(GeoExpressions.stHausdorffDistanceInfo) + extensions.injectFunction(GeoExpressions.stSimplifyPreserveTopologyInfo) + extensions.injectFunction(GeoExpressions.stFlipCoordinatesInfo) + extensions.injectFunction(GeoExpressions.stBoundaryInfo) + extensions.injectFunction(GeoExpressions.stDifferenceInfo) + extensions.injectFunction(GeoExpressions.stSymDifferenceInfo) extensions.injectColumnar { session => CometScanColumnar(session) } extensions.injectColumnar { session => CometExecColumnar(session) } // Pre-3.5 only: tag AQE DPP regions so the conversion rules below leave them Spark-native. @@ -109,6 +150,7 @@ class CometSparkSessionExtensions override def postColumnarTransitions: Rule[SparkPlan] = EliminateRedundantTransitions(session) } + } object CometSparkSessionExtensions extends Logging { diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometGeoFallback.scala b/spark/src/main/scala/org/apache/comet/expressions/CometGeoFallback.scala new file mode 100644 index 0000000000..9499e9dc5b --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/CometGeoFallback.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.expressions + +/** + * JVM fallback implementations for Comet geo UDFs. Called only when Comet native execution is + * disabled. When Comet is active the expression is serded to ScalarFunc and executed via the Rust + * geo crate in DataFusion. + * + * Requires Apache Sedona on the classpath to function. Without Sedona and without Comet enabled + * an UnsupportedOperationException is thrown at runtime. + */ +object CometGeoFallback { + + private def notSupported(fn: String): Nothing = + throw new UnsupportedOperationException( + s"$fn requires either Comet native execution (spark.comet.exec.enabled=true) " + + s"or Apache Sedona on the classpath for JVM fallback.") + + // Constructors + def geomFromWkt(g: String): String = notSupported("st_geomfromwkt") + def geomFromGeoJson(g: String): String = notSupported("st_geomfromgeojson") + def makeEnvelope(xmin: Double, ymin: Double, xmax: Double, ymax: Double): String = + notSupported("st_makeenvelope") + def makePoint(x: String, y: String): String = "POINT(" + x + " " + y + ")" + def makeLine(g1: String, g2: String): String = notSupported("st_makeline") + // Serializers + def asText(g: String): String = notSupported("st_astext") + def asGeoJson(g: String): String = notSupported("st_asgeojson") + // Predicates + def contains(g1: String, g2: String): Boolean = notSupported("st_contains") + def intersects(g1: String, g2: String): Boolean = notSupported("st_intersects") + def within(g1: String, g2: String): Boolean = notSupported("st_within") + def covers(g1: String, g2: String): Boolean = notSupported("st_covers") + def coveredBy(g1: String, g2: String): Boolean = notSupported("st_coveredby") + def equals(g1: String, g2: String): Boolean = notSupported("st_equals") + def touches(g1: String, g2: String): Boolean = notSupported("st_touches") + def crosses(g1: String, g2: String): Boolean = notSupported("st_crosses") + def disjoint(g1: String, g2: String): Boolean = notSupported("st_disjoint") + def overlaps(g1: String, g2: String): Boolean = notSupported("st_overlaps") + // Measurements + def distance(g1: String, g2: String): Double = notSupported("st_distance") + def distanceSphere(g1: String, g2: String): Double = notSupported("st_distancesphere") + def area(g: String): Double = notSupported("st_area") + def length(g: String): Double = notSupported("st_length") + def perimeter(g: String): Double = notSupported("st_perimeter") + def hausdorffDistance(g1: String, g2: String): Double = notSupported("st_hausdorffdistance") + // Transformations + def centroid(g: String): String = notSupported("st_centroid") + def envelope(g: String): String = notSupported("st_envelope") + def convexHull(g: String): String = notSupported("st_convexhull") + def simplify(g: String, tolerance: Double): String = notSupported("st_simplify") + def simplifyPreserveTopology(g: String, tolerance: Double): String = + notSupported("st_simplifypreservetopology") + def flipCoordinates(g: String): String = notSupported("st_flipcoordinates") + def boundary(g: String): String = notSupported("st_boundary") + def buffer(g: String, distance: Double): String = notSupported("st_buffer") + // Set operations + def union(g1: String, g2: String): String = notSupported("st_union") + def intersection(g1: String, g2: String): String = notSupported("st_intersection") + def difference(g1: String, g2: String): String = notSupported("st_difference") + def symDifference(g1: String, g2: String): String = notSupported("st_symdifference") + // Accessors + def isEmpty(g: String): Boolean = notSupported("st_isempty") + def geometryType(g: String): String = notSupported("st_geometrytype") + def numPoints(g: String): Long = notSupported("st_numpoints") + def stX(g: String): Double = notSupported("st_x") + def stY(g: String): Double = notSupported("st_y") +} diff --git a/spark/src/main/scala/org/apache/comet/expressions/GeoExpressions.scala b/spark/src/main/scala/org/apache/comet/expressions/GeoExpressions.scala new file mode 100644 index 0000000000..7a75b189e0 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/GeoExpressions.scala @@ -0,0 +1,870 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.expressions + +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionInfo, NullIntolerant, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, IntegerType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +// ---- Binary geo predicates ----------------------------------------------- + +case class StContains(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.contains(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".contains($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StIntersects(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.intersects(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".intersects($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StWithin(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.within(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".within($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StDistance(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.distance(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".distance($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StUnion(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.union(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".union($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StIntersection(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.intersection(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".intersection($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StPoint(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.makePoint(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".makePoint($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Unary geo functions -------------------------------------------------- + +case class StArea(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g: Any): Any = CometGeoFallback.area(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.area($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StCentroid(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.centroid(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.centroid($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StLength(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g: Any): Any = CometGeoFallback.length(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.length($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StIsEmpty(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g: Any): Any = CometGeoFallback.isEmpty(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.isEmpty($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StGeometryType(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.geometryType(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".geometryType($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StNumPoints(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = LongType + override def nullSafeEval(g: Any): Any = CometGeoFallback.numPoints(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.numPoints($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StX(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g: Any): Any = CometGeoFallback.stX(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.stX($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StY(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g: Any): Any = CometGeoFallback.stY(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.stY($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StEnvelope(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.envelope(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.envelope($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StConvexHull(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.convexHull(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".convexHull($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +// st_simplify and st_buffer take two args (geom + numeric param) + +case class StSimplify(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any, t: Any): Any = + UTF8String.fromString(CometGeoFallback.simplify(g.toString, t.toString.toDouble)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g, t) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".simplify($g.toString(), Double.parseDouble($t.toString())))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StBuffer(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any, d: Any): Any = + UTF8String.fromString(CometGeoFallback.buffer(g.toString, d.toString.toDouble)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g, d) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".buffer($g.toString(), Double.parseDouble($d.toString())))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Constructors -------------------------------------------------------- + +case class StGeomFromWkt(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.geomFromWkt(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.geomFromWkt($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StGeomFromGeoJson(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.geomFromGeoJson(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".geomFromGeoJson($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StMakeEnvelope(xmin: Expression, ymin: Expression, xmax: Expression, ymax: Expression) + extends Expression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullable: Boolean = true + override def children: Seq[Expression] = Seq(xmin, ymin, xmax, ymax) + override def eval(input: InternalRow): Any = { + val xv = xmin.eval(input) + val yv = ymin.eval(input) + val xv2 = xmax.eval(input) + val yv2 = ymax.eval(input) + if (xv == null || yv == null || xv2 == null || yv2 == null) { + null + } else { + UTF8String.fromString( + CometGeoFallback.makeEnvelope( + xv.toString.toDouble, + yv.toString.toDouble, + xv2.toString.toDouble, + yv2.toString.toDouble)) + } + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = + copy( + xmin = newChildren(0), + ymin = newChildren(1), + xmax = newChildren(2), + ymax = newChildren(3)) +} + +case class StMakeLine(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.makeLine(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".makeLine($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Serializers --------------------------------------------------------- + +case class StAsText(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.asText(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.asText($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StAsGeoJson(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.asGeoJson(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.asGeoJson($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +// ---- Additional predicates ----------------------------------------------- + +case class StCovers(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.covers(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".covers($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StCoveredBy(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.coveredBy(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".coveredBy($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StEquals(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.equals(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".equals($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StTouches(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.touches(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".touches($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StCrosses(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.crosses(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".crosses($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StDisjoint(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.disjoint(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".disjoint($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StOverlaps(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = BooleanType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.overlaps(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".overlaps($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Additional measurements --------------------------------------------- + +case class StDistanceSphere(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.distanceSphere(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".distanceSphere($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StPerimeter(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g: Any): Any = CometGeoFallback.perimeter(g.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.perimeter($g.toString())") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StHausdorffDistance(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = DoubleType + override def nullSafeEval(g1: Any, g2: Any): Any = + CometGeoFallback.hausdorffDistance(g1.toString, g2.toString) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".hausdorffDistance($g1.toString(), $g2.toString())") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Additional transformations ------------------------------------------ + +case class StSimplifyPreserveTopology(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString( + CometGeoFallback.simplifyPreserveTopology(g1.toString, g2.toString.toDouble)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".simplifyPreserveTopology($g1.toString(), Double.parseDouble($g2.toString())))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StFlipCoordinates(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.flipCoordinates(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".flipCoordinates($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +case class StBoundary(child: Expression) extends UnaryExpression with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g: Any): Any = + UTF8String.fromString(CometGeoFallback.boundary(g.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + g => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$.boundary($g.toString()))") + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +// ---- Additional set operations ------------------------------------------- + +case class StDifference(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.difference(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".difference($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +case class StSymDifference(left: Expression, right: Expression) + extends BinaryExpression + with NullIntolerant { + override def dataType: DataType = StringType + override def nullSafeEval(g1: Any, g2: Any): Any = + UTF8String.fromString(CometGeoFallback.symDifference(g1.toString, g2.toString)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen( + ctx, + ev, + (g1, g2) => + s"org.apache.spark.unsafe.types.UTF8String.fromString(" + + s"org.apache.comet.expressions.CometGeoFallback$$.MODULE$$" + + s".symDifference($g1.toString(), $g2.toString()))") + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(left = newLeft, right = newRight) +} + +// ---- Registration helpers for SparkSessionExtensions.injectFunction ------ + +object GeoExpressions { + + type FunctionDescription = + (FunctionIdentifier, ExpressionInfo, Seq[Expression] => Expression) + + private def desc( + name: String, + cls: Class[_], + builder: Seq[Expression] => Expression): FunctionDescription = + (new FunctionIdentifier(name), new ExpressionInfo(cls.getName, name), builder) + + val stContainsInfo: FunctionDescription = + desc("st_contains", classOf[StContains], { args => StContains(args(0), args(1)) }) + + val stIntersectsInfo: FunctionDescription = + desc("st_intersects", classOf[StIntersects], { args => StIntersects(args(0), args(1)) }) + + val stWithinInfo: FunctionDescription = + desc("st_within", classOf[StWithin], { args => StWithin(args(0), args(1)) }) + + val stDistanceInfo: FunctionDescription = + desc("st_distance", classOf[StDistance], { args => StDistance(args(0), args(1)) }) + + val stAreaInfo: FunctionDescription = + desc("st_area", classOf[StArea], { args => StArea(args(0)) }) + + val stCentroidInfo: FunctionDescription = + desc("st_centroid", classOf[StCentroid], { args => StCentroid(args(0)) }) + + val stLengthInfo: FunctionDescription = + desc("st_length", classOf[StLength], { args => StLength(args(0)) }) + + val stIsEmptyInfo: FunctionDescription = + desc("st_isempty", classOf[StIsEmpty], { args => StIsEmpty(args(0)) }) + + val stGeometryTypeInfo: FunctionDescription = + desc("st_geometrytype", classOf[StGeometryType], { args => StGeometryType(args(0)) }) + + val stNumPointsInfo: FunctionDescription = + desc("st_numpoints", classOf[StNumPoints], { args => StNumPoints(args(0)) }) + + val stXInfo: FunctionDescription = + desc("st_x", classOf[StX], { args => StX(args(0)) }) + + val stYInfo: FunctionDescription = + desc("st_y", classOf[StY], { args => StY(args(0)) }) + + val stEnvelopeInfo: FunctionDescription = + desc("st_envelope", classOf[StEnvelope], { args => StEnvelope(args(0)) }) + + val stConvexHullInfo: FunctionDescription = + desc("st_convexhull", classOf[StConvexHull], { args => StConvexHull(args(0)) }) + + val stSimplifyInfo: FunctionDescription = + desc("st_simplify", classOf[StSimplify], { args => StSimplify(args(0), args(1)) }) + + val stBufferInfo: FunctionDescription = + desc("st_buffer", classOf[StBuffer], { args => StBuffer(args(0), args(1)) }) + + val stUnionInfo: FunctionDescription = + desc("st_union", classOf[StUnion], { args => StUnion(args(0), args(1)) }) + + val stIntersectionInfo: FunctionDescription = + desc("st_intersection", classOf[StIntersection], { args => StIntersection(args(0), args(1)) }) + + val stGeomFromWktInfo: FunctionDescription = + desc("st_geomfromwkt", classOf[StGeomFromWkt], { args => StGeomFromWkt(args(0)) }) + + val stGeomFromGeoJsonInfo: FunctionDescription = + desc("st_geomfromgeojson", classOf[StGeomFromGeoJson], { args => StGeomFromGeoJson(args(0)) }) + + val stPointInfo: FunctionDescription = + desc("st_point", classOf[StPoint], { args => StPoint(args(0), args(1)) }) + + val stMakeEnvelopeInfo: FunctionDescription = + desc( + "st_makeenvelope", + classOf[StMakeEnvelope], + { args => StMakeEnvelope(args(0), args(1), args(2), args(3)) }) + + val stMakeLineInfo: FunctionDescription = + desc("st_makeline", classOf[StMakeLine], { args => StMakeLine(args(0), args(1)) }) + + val stAsTextInfo: FunctionDescription = + desc("st_astext", classOf[StAsText], { args => StAsText(args(0)) }) + + val stAsGeoJsonInfo: FunctionDescription = + desc("st_asgeojson", classOf[StAsGeoJson], { args => StAsGeoJson(args(0)) }) + + val stCoversInfo: FunctionDescription = + desc("st_covers", classOf[StCovers], { args => StCovers(args(0), args(1)) }) + + val stCoveredByInfo: FunctionDescription = + desc("st_coveredby", classOf[StCoveredBy], { args => StCoveredBy(args(0), args(1)) }) + + val stEqualsInfo: FunctionDescription = + desc("st_equals", classOf[StEquals], { args => StEquals(args(0), args(1)) }) + + val stTouchesInfo: FunctionDescription = + desc("st_touches", classOf[StTouches], { args => StTouches(args(0), args(1)) }) + + val stCrossesInfo: FunctionDescription = + desc("st_crosses", classOf[StCrosses], { args => StCrosses(args(0), args(1)) }) + + val stDisjointInfo: FunctionDescription = + desc("st_disjoint", classOf[StDisjoint], { args => StDisjoint(args(0), args(1)) }) + + val stOverlapsInfo: FunctionDescription = + desc("st_overlaps", classOf[StOverlaps], { args => StOverlaps(args(0), args(1)) }) + + val stDistanceSphereInfo: FunctionDescription = + desc( + "st_distancesphere", + classOf[StDistanceSphere], + { args => StDistanceSphere(args(0), args(1)) }) + + val stPerimeterInfo: FunctionDescription = + desc("st_perimeter", classOf[StPerimeter], { args => StPerimeter(args(0)) }) + + val stHausdorffDistanceInfo: FunctionDescription = + desc( + "st_hausdorffdistance", + classOf[StHausdorffDistance], + { args => StHausdorffDistance(args(0), args(1)) }) + + val stSimplifyPreserveTopologyInfo: FunctionDescription = + desc( + "st_simplifypreservetopology", + classOf[StSimplifyPreserveTopology], + { args => StSimplifyPreserveTopology(args(0), args(1)) }) + + val stFlipCoordinatesInfo: FunctionDescription = + desc("st_flipcoordinates", classOf[StFlipCoordinates], { args => StFlipCoordinates(args(0)) }) + + val stBoundaryInfo: FunctionDescription = + desc("st_boundary", classOf[StBoundary], { args => StBoundary(args(0)) }) + + val stDifferenceInfo: FunctionDescription = + desc("st_difference", classOf[StDifference], { args => StDifference(args(0), args(1)) }) + + val stSymDifferenceInfo: FunctionDescription = + desc( + "st_symdifference", + classOf[StSymDifference], + { args => StSymDifference(args(0), args(1)) }) +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8d48239e76..80ac5de050 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -274,6 +274,10 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[StaticInvoke] -> CometStaticInvoke, classOf[UnscaledValue] -> CometUnscaledValue) + // Native Comet geo expressions + optional Sedona ST_ expressions. + private[comet] val geoExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = + CometGeoExpr.buildSerdeMap() + /** * Mapping of Spark expression class to Comet expression handler. */ @@ -281,7 +285,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { mathExpressions ++ hashExpressions ++ stringExpressions ++ conditionalExpressions ++ mapExpressions ++ predicateExpressions ++ structExpressions ++ bitwiseExpressions ++ miscExpressions ++ arrayExpressions ++ - temporalExpressions ++ conversionExpressions ++ urlExpressions + temporalExpressions ++ conversionExpressions ++ urlExpressions ++ geoExpressions /** * Mapping of Spark aggregate expression class to Comet expression handler. diff --git a/spark/src/main/scala/org/apache/comet/serde/geo.scala b/spark/src/main/scala/org/apache/comet/serde/geo.scala new file mode 100644 index 0000000000..549be98a0d --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/geo.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.util.Try + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +import org.apache.comet.expressions.{StArea, StAsGeoJson, StAsText, StBoundary, StBuffer, StCentroid, StContains, StConvexHull, StCoveredBy, StCovers, StCrosses, StDifference, StDisjoint, StDistance, StDistanceSphere, StEnvelope, StEquals, StFlipCoordinates, StGeometryType, StGeomFromGeoJson, StGeomFromWkt, StHausdorffDistance, StIntersection, StIntersects, StIsEmpty, StLength, StMakeEnvelope, StMakeLine, StNumPoints, StOverlaps, StPerimeter, StPoint, StSimplify, StSimplifyPreserveTopology, StSymDifference, StTouches, StUnion, StWithin, StX, StY} +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} + +/** + * Serde for native Comet geo expressions and optional Sedona ST_ expressions. Maps each to the + * corresponding named ScalarFunc so the DataFusion planner resolves it to the Rust geo UDF. + * + * Sedona entries are added only when Sedona is present on the classpath at runtime. + */ +private[serde] object CometGeoExpr { + + def buildSerdeMap(): Map[Class[_ <: Expression], CometExpressionSerde[_]] = { + val nativeEntries: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[StContains] -> new CometGeoScalarFunc("st_contains"), + classOf[StIntersects] -> new CometGeoScalarFunc("st_intersects"), + classOf[StWithin] -> new CometGeoScalarFunc("st_within"), + classOf[StDistance] -> new CometGeoScalarFunc("st_distance"), + classOf[StArea] -> new CometGeoScalarFunc("st_area"), + classOf[StCentroid] -> new CometGeoScalarFunc("st_centroid"), + classOf[StLength] -> new CometGeoScalarFunc("st_length"), + classOf[StIsEmpty] -> new CometGeoScalarFunc("st_isempty"), + classOf[StGeometryType] -> new CometGeoScalarFunc("st_geometrytype"), + classOf[StNumPoints] -> new CometGeoScalarFunc("st_numpoints"), + classOf[StX] -> new CometGeoScalarFunc("st_x"), + classOf[StY] -> new CometGeoScalarFunc("st_y"), + classOf[StEnvelope] -> new CometGeoScalarFunc("st_envelope"), + classOf[StConvexHull] -> new CometGeoScalarFunc("st_convexhull"), + classOf[StSimplify] -> new CometGeoScalarFunc("st_simplify"), + classOf[StBuffer] -> new CometGeoScalarFunc("st_buffer"), + classOf[StUnion] -> new CometGeoScalarFunc("st_union"), + classOf[StIntersection] -> new CometGeoScalarFunc("st_intersection"), + classOf[StGeomFromWkt] -> new CometGeoScalarFunc("st_geomfromwkt"), + classOf[StGeomFromGeoJson] -> new CometGeoScalarFunc("st_geomfromgeojson"), + classOf[StPoint] -> new CometGeoScalarFunc("st_point"), + classOf[StMakeEnvelope] -> new CometGeoScalarFunc("st_makeenvelope"), + classOf[StMakeLine] -> new CometGeoScalarFunc("st_makeline"), + classOf[StAsText] -> new CometGeoScalarFunc("st_astext"), + classOf[StAsGeoJson] -> new CometGeoScalarFunc("st_asgeojson"), + classOf[StCovers] -> new CometGeoScalarFunc("st_covers"), + classOf[StCoveredBy] -> new CometGeoScalarFunc("st_coveredby"), + classOf[StEquals] -> new CometGeoScalarFunc("st_equals"), + classOf[StTouches] -> new CometGeoScalarFunc("st_touches"), + classOf[StCrosses] -> new CometGeoScalarFunc("st_crosses"), + classOf[StDisjoint] -> new CometGeoScalarFunc("st_disjoint"), + classOf[StOverlaps] -> new CometGeoScalarFunc("st_overlaps"), + classOf[StDistanceSphere] -> new CometGeoScalarFunc("st_distancesphere"), + classOf[StPerimeter] -> new CometGeoScalarFunc("st_perimeter"), + classOf[StHausdorffDistance] -> new CometGeoScalarFunc("st_hausdorffdistance"), + classOf[StSimplifyPreserveTopology] -> new CometGeoScalarFunc( + "st_simplifypreservetopology"), + classOf[StFlipCoordinates] -> new CometGeoScalarFunc("st_flipcoordinates"), + classOf[StBoundary] -> new CometGeoScalarFunc("st_boundary"), + classOf[StDifference] -> new CometGeoScalarFunc("st_difference"), + classOf[StSymDifference] -> new CometGeoScalarFunc("st_symdifference")) + + val sedonaEntries: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Seq( + "org.apache.sedona.sql.utils.expressions.ST_Contains" -> "st_contains", + "org.apache.sedona.sql.utils.expressions.ST_Intersects" -> "st_intersects", + "org.apache.sedona.sql.utils.expressions.ST_Distance" -> "st_distance", + "org.apache.sedona.sql.utils.expressions.ST_Within" -> "st_within", + "org.apache.sedona.sql.utils.expressions.ST_Area" -> "st_area", + "org.apache.sedona.sql.utils.expressions.ST_Centroid" -> "st_centroid", + "org.apache.sedona.sql.utils.expressions.ST_Length" -> "st_length", + "org.apache.sedona.sql.utils.expressions.ST_IsEmpty" -> "st_isempty", + "org.apache.sedona.sql.utils.expressions.ST_GeometryType" -> "st_geometrytype", + "org.apache.sedona.sql.utils.expressions.ST_NumPoints" -> "st_numpoints", + "org.apache.sedona.sql.utils.expressions.ST_X" -> "st_x", + "org.apache.sedona.sql.utils.expressions.ST_Y" -> "st_y", + "org.apache.sedona.sql.utils.expressions.ST_Envelope" -> "st_envelope", + "org.apache.sedona.sql.utils.expressions.ST_ConvexHull" -> "st_convexhull", + "org.apache.sedona.sql.utils.expressions.ST_Simplify" -> "st_simplify", + "org.apache.sedona.sql.utils.expressions.ST_Buffer" -> "st_buffer", + "org.apache.sedona.sql.utils.expressions.ST_Union" -> "st_union", + "org.apache.sedona.sql.utils.expressions.ST_Intersection" -> "st_intersection").flatMap { + case (className, funcName) => + // scalastyle:off classforname + Try(Class.forName(className).asInstanceOf[Class[Expression]]) + // scalastyle:on classforname + .toOption + .map(cls => cls -> new CometGeoScalarFunc(funcName)) + }.toMap + + nativeEntries ++ sedonaEntries + } +} + +/** + * Generic serde for a geo expression: emits ScalarFunc { func = funcName } so the DataFusion + * planner resolves it to the named Rust UDF registered in the SessionContext. + */ +private[serde] class CometGeoScalarFunc(funcName: String) + extends CometExpressionSerde[Expression] { + + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExprs = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProto(funcName, childExprs: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +}