Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ fn op_name(op: &OpStruct) -> &'static str {
OpStruct::Explode(_) => "Explode",
OpStruct::CsvScan(_) => "CsvScan",
OpStruct::ShuffleScan(_) => "ShuffleScan",
OpStruct::BroadcastNestedLoopJoin(_) => "BroadcastNestedLoopJoin",
}
}

Expand Down
35 changes: 35 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ use arrow::row::{OwnedRow, RowConverter, SortField};
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::UnnestOptions;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::joins::NestedLoopJoinExec;
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec};
use datafusion_comet_proto::spark_expression::ListLiteral;
Expand Down Expand Up @@ -1197,6 +1198,40 @@ impl PhysicalPlanner {
))
}
}

OpStruct::BroadcastNestedLoopJoin(bnlj) => {
let (join_params, scans, shuffle_scans) = self.parse_join_parameters(
inputs,
children,
&[],
&[],
bnlj.join_type,
&bnlj.condition,
partition_count,
)?;

let left = Arc::clone(&join_params.left.native_plan);
let right = Arc::clone(&join_params.right.native_plan);

let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
left,
right,
join_params.join_filter,
&join_params.join_type,
None,
)?);

Ok((
scans,
shuffle_scans,
Arc::new(SparkPlan::new(
spark_plan.plan_id,
nested_loop_join,
vec![join_params.left, join_params.right],
)),
))
}

OpStruct::Limit(limit) => {
assert_eq!(children.len(), 1);
let num = limit.limit;
Expand Down
1 change: 1 addition & 0 deletions native/core/src/execution/planner/operator_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option<OperatorType> {
OpStruct::Explode(_) => None, // Not yet in OperatorType enum
OpStruct::CsvScan(_) => Some(OperatorType::CsvScan),
OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum
OpStruct::BroadcastNestedLoopJoin(_) => None,
}
}
7 changes: 7 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ message Operator {
Explode explode = 114;
CsvScan csv_scan = 115;
ShuffleScan shuffle_scan = 116;
BroadcastNestedLoopJoin broadcast_nested_loop_join = 117;
}
}

Expand Down Expand Up @@ -384,6 +385,12 @@ message SortMergeJoin {
optional spark.spark_expression.Expr condition = 5;
}

message BroadcastNestedLoopJoin {
JoinType join_type = 1;
BuildSide build_side = 2;
optional spark.spark_expression.Expr condition = 3;
}

enum JoinType {
Inner = 0;
LeftOuter = 1;
Expand Down
2 changes: 2 additions & 0 deletions spark/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ object CometConf extends ShimCometConf {
createExecEnabledConfig("broadcastExchange", defaultValue = true)
val COMET_EXEC_HASH_JOIN_ENABLED: ConfigEntry[Boolean] =
createExecEnabledConfig("hashJoin", defaultValue = true)
val COMET_EXEC_BROADCAST_NESTED_LOOP_JOIN_ENABLED: ConfigEntry[Boolean] =
createExecEnabledConfig("broadcastNestedLoopJoin", defaultValue = true)
val COMET_EXEC_SORT_MERGE_JOIN_ENABLED: ConfigEntry[Boolean] =
createExecEnabledConfig("sortMergeJoin", defaultValue = true)
val COMET_EXEC_AGGREGATE_ENABLED: ConfigEntry[Boolean] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -81,6 +81,7 @@ object CometExecRule {
classOf[HashAggregateExec] -> CometHashAggregateExec,
classOf[ObjectHashAggregateExec] -> CometObjectHashAggregateExec,
classOf[BroadcastHashJoinExec] -> CometBroadcastHashJoinExec,
classOf[BroadcastNestedLoopJoinExec] -> CometBroadcastNestedLoopJoinExec,
classOf[ShuffledHashJoinExec] -> CometHashJoinExec,
classOf[SortMergeJoinExec] -> CometSortMergeJoinExec,
classOf[SortExec] -> CometSortExec,
Expand Down
111 changes: 110 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -1918,6 +1918,115 @@ trait CometHashJoin {
}
}

case class CometBroadcastNestedLoopJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
override val output: Seq[Attribute],
override val outputOrdering: Seq[SortOrder],
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
}

object CometBroadcastNestedLoopJoinExec extends CometOperatorSerde[BroadcastNestedLoopJoinExec] {

/**
* Get the optional Comet configuration entry that is used to enable or disable native support
* for this operator.
*/
override def enabledConfig: Option[ConfigEntry[Boolean]] = {
Some(CometConf.COMET_EXEC_BROADCAST_NESTED_LOOP_JOIN_ENABLED)
}

override def getSupportLevel(op: BroadcastNestedLoopJoinExec): SupportLevel =
Compatible(None)

/**
* Convert a Spark operator into a protocol buffer representation that can be passed into native
* code.
*
* @param op
* The Spark operator.
* @param builder
* The protobuf builder for the operator.
* @param childOp
* Child operators that have already been converted to Comet.
* @return
* Protocol buffer representation, or None if the operator could not be converted. In this
* case it is expected that the input operator will have been tagged with reasons why it could
* not be converted.
*/
override def convert(
op: BroadcastNestedLoopJoinExec,
builder: Operator.Builder,
childOp: Operator*): Option[Operator] = {

val buildSide = op.buildSide match {
case BuildLeft => OperatorOuterClass.BuildSide.BuildLeft
case BuildRight => OperatorOuterClass.BuildSide.BuildRight
}

val join = op.joinType
val joinType = {
import OperatorOuterClass.JoinType
join match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(op, s"Unsupported join type $join")
return None
}
}

val joinCondition = op.condition.map({ cond =>
val condProto = exprToProto(cond, op.left.output ++ op.right.output)
if (condProto.isEmpty) {
withInfo(op, cond)
return None
}
condProto.get
})

val joinBuilder = OperatorOuterClass.BroadcastNestedLoopJoin
.newBuilder()
.setJoinType(joinType)
.setBuildSide(buildSide)

joinCondition.foreach(joinBuilder.setCondition)

Some(builder.setBroadcastNestedLoopJoin(joinBuilder).build())
}

override def createExec(
nativeOp: Operator,
op: BroadcastNestedLoopJoinExec): CometNativeExec = {

CometBroadcastNestedLoopJoinExec(
nativeOp,
op,
op.output,
op.outputOrdering,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
}
}

object CometBroadcastHashJoinExec extends CometOperatorSerde[HashJoin] with CometHashJoin {

override def enabledConfig: Option[ConfigEntry[Boolean]] =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
CometNativeColumnarToRow
+- CometBroadcastNestedLoopJoin
:- CometBroadcastNestedLoopJoin
: :- CometBroadcastNestedLoopJoin
: : :- CometBroadcastNestedLoopJoin
: : : :- CometBroadcastNestedLoopJoin
: : : : :- CometHashAggregate
: : : : : +- CometExchange
: : : : : +- CometHashAggregate
: : : : : +- CometHashAggregate
: : : : : +- CometExchange
: : : : : +- CometHashAggregate
: : : : : +- CometProject
: : : : : +- CometFilter
: : : : : +- CometNativeScan parquet spark_catalog.default.store_sales
: : : : +- CometBroadcastExchange
: : : : +- CometHashAggregate
: : : : +- CometExchange
: : : : +- CometHashAggregate
: : : : +- CometHashAggregate
: : : : +- CometExchange
: : : : +- CometHashAggregate
: : : : +- CometProject
: : : : +- CometFilter
: : : : +- CometNativeScan parquet spark_catalog.default.store_sales
: : : +- CometBroadcastExchange
: : : +- CometHashAggregate
: : : +- CometExchange
: : : +- CometHashAggregate
: : : +- CometHashAggregate
: : : +- CometExchange
: : : +- CometHashAggregate
: : : +- CometProject
: : : +- CometFilter
: : : +- CometNativeScan parquet spark_catalog.default.store_sales
: : +- CometBroadcastExchange
: : +- CometHashAggregate
: : +- CometExchange
: : +- CometHashAggregate
: : +- CometHashAggregate
: : +- CometExchange
: : +- CometHashAggregate
: : +- CometProject
: : +- CometFilter
: : +- CometNativeScan parquet spark_catalog.default.store_sales
: +- CometBroadcastExchange
: +- CometHashAggregate
: +- CometExchange
: +- CometHashAggregate
: +- CometHashAggregate
: +- CometExchange
: +- CometHashAggregate
: +- CometProject
: +- CometFilter
: +- CometNativeScan parquet spark_catalog.default.store_sales
+- CometBroadcastExchange
+- CometHashAggregate
+- CometExchange
+- CometHashAggregate
+- CometHashAggregate
+- CometExchange
+- CometHashAggregate
+- CometProject
+- CometFilter
+- CometNativeScan parquet spark_catalog.default.store_sales

Comet accelerated 64 out of 64 eligible operators (100%). Final plan contains 1 transitions between Spark and Comet.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
CometNativeColumnarToRow
+- CometProject
+- CometBroadcastNestedLoopJoin
:- CometHashAggregate
: +- CometExchange
: +- CometHashAggregate
: +- CometProject
: +- CometBroadcastHashJoin
: :- CometProject
: : +- CometBroadcastHashJoin
: : :- CometProject
: : : +- CometBroadcastHashJoin
: : : :- CometProject
: : : : +- CometBroadcastHashJoin
: : : : :- CometProject
: : : : : +- CometBroadcastHashJoin
: : : : : :- CometProject
: : : : : : +- CometBroadcastHashJoin
: : : : : : :- CometFilter
: : : : : : : +- CometNativeScan parquet spark_catalog.default.store_sales
: : : : : : : +- CometSubqueryBroadcast
: : : : : : : +- CometBroadcastExchange
: : : : : : : +- CometProject
: : : : : : : +- CometFilter
: : : : : : : +- CometNativeScan parquet spark_catalog.default.date_dim
: : : : : : +- CometBroadcastExchange
: : : : : : +- CometProject
: : : : : : +- CometFilter
: : : : : : +- CometNativeScan parquet spark_catalog.default.store
: : : : : +- CometBroadcastExchange
: : : : : +- CometProject
: : : : : +- CometFilter
: : : : : +- CometNativeScan parquet spark_catalog.default.promotion
: : : : +- CometBroadcastExchange
: : : : +- CometProject
: : : : +- CometFilter
: : : : +- CometNativeScan parquet spark_catalog.default.date_dim
: : : +- CometBroadcastExchange
: : : +- CometFilter
: : : +- CometNativeScan parquet spark_catalog.default.customer
: : +- CometBroadcastExchange
: : +- CometProject
: : +- CometFilter
: : +- CometNativeScan parquet spark_catalog.default.customer_address
: +- CometBroadcastExchange
: +- CometProject
: +- CometFilter
: +- CometNativeScan parquet spark_catalog.default.item
+- CometBroadcastExchange
+- CometHashAggregate
+- CometExchange
+- CometHashAggregate
+- CometProject
+- CometBroadcastHashJoin
:- CometProject
: +- CometBroadcastHashJoin
: :- CometProject
: : +- CometBroadcastHashJoin
: : :- CometProject
: : : +- CometBroadcastHashJoin
: : : :- CometProject
: : : : +- CometBroadcastHashJoin
: : : : :- CometFilter
: : : : : +- CometNativeScan parquet spark_catalog.default.store_sales
: : : : : +- ReusedSubquery
: : : : +- CometBroadcastExchange
: : : : +- CometProject
: : : : +- CometFilter
: : : : +- CometNativeScan parquet spark_catalog.default.store
: : : +- CometBroadcastExchange
: : : +- CometProject
: : : +- CometFilter
: : : +- CometNativeScan parquet spark_catalog.default.date_dim
: : +- CometBroadcastExchange
: : +- CometFilter
: : +- CometNativeScan parquet spark_catalog.default.customer
: +- CometBroadcastExchange
: +- CometProject
: +- CometFilter
: +- CometNativeScan parquet spark_catalog.default.customer_address
+- CometBroadcastExchange
+- CometProject
+- CometFilter
+- CometNativeScan parquet spark_catalog.default.item

Comet accelerated 81 out of 83 eligible operators (97%). Final plan contains 1 transitions between Spark and Comet.
Loading
Loading