diff --git a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp index 3fc105aa38b..0c56a0a619d 100644 --- a/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp +++ b/dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp @@ -261,6 +261,12 @@ Block ScanHashMapAfterProbeBlockInputStream::readImpl() else fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); break; + case ASTTableJoin::Kind::Full: + if (parent.has_other_condition) + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); + else + fillColumnsUsingCurrentPartition(columns_left, columns_right, row_counter_column); + break; case ASTTableJoin::Kind::RightAnti: case ASTTableJoin::Kind::RightOuter: if (parent.has_other_condition) diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.cpp b/dbms/src/Debug/MockExecutor/JoinBinder.cpp index 01011dd4e66..778064f0e0a 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.cpp +++ b/dbms/src/Debug/MockExecutor/JoinBinder.cpp @@ -19,11 +19,63 @@ #include #include #include +#include #include #include namespace DB::mock { +namespace +{ +void appendJoinSchema(DAGSchema & output_schema, const DAGSchema & input_schema, bool make_nullable) +{ + for (const auto & field : input_schema) + { + if (make_nullable && field.second.hasNotNullFlag()) + output_schema.push_back(toNullableDAGColumnInfo(field)); + else + output_schema.push_back(field); + } +} + +void buildLeftSideJoinSchema(DAGSchema & schema, const DAGSchema & left_schema, tipb::JoinType tp) +{ + appendJoinSchema(schema, left_schema, JoinInterpreterHelper::makeLeftJoinSideNullable(tp)); +} + +void buildRightSideJoinSchema(DAGSchema & schema, const DAGSchema & right_schema, tipb::JoinType tp) +{ + /// Note: for semi join, the right table column is ignored + /// but for (anti) left outer semi join, a 1/0 (uint8) field is pushed back + /// indicating whether right table has matching row(s), see comment in ASTTableJoin::Kind for details. + if (tp == tipb::JoinType::TypeLeftOuterSemiJoin || tp == tipb::JoinType::TypeAntiLeftOuterSemiJoin) + { + tipb::FieldType field_type{}; + field_type.set_tp(TiDB::TypeTiny); + field_type.set_charset("binary"); + field_type.set_collate(TiDB::ITiDBCollator::BINARY); + field_type.set_flag(0); + field_type.set_flen(-1); + field_type.set_decimal(-1); + schema.push_back(std::make_pair("", TiDB::fieldTypeToColumnInfo(field_type))); + } + else if (tp != tipb::JoinType::TypeSemiJoin && tp != tipb::JoinType::TypeAntiSemiJoin) + { + appendJoinSchema(schema, right_schema, JoinInterpreterHelper::makeRightJoinSideNullable(tp)); + } +} + +DAGSchema buildOtherConditionSchema( + const DAGSchema & left_schema, + const DAGSchema & right_schema, + tipb::JoinType join_type) +{ + DAGSchema merged_children_schema; + appendJoinSchema(merged_children_schema, left_schema, JoinInterpreterHelper::makeLeftJoinSideNullable(join_type)); + appendJoinSchema(merged_children_schema, right_schema, JoinInterpreterHelper::makeRightJoinSideNullable(join_type)); + return merged_children_schema; +} +} // namespace void JoinBinder::addRuntimeFilter(MockRuntimeFilter & rf) { @@ -95,22 +147,8 @@ void JoinBinder::columnPrune(std::unordered_set & used_columns) /// update output schema output_schema.clear(); - - for (auto & field : children[0]->output_schema) - { - if (tp == tipb::TypeRightOuterJoin && field.second.hasNotNullFlag()) - output_schema.push_back(toNullableDAGColumnInfo(field)); - else - output_schema.push_back(field); - } - - for (auto & field : children[1]->output_schema) - { - if (tp == tipb::TypeLeftOuterJoin && field.second.hasNotNullFlag()) - output_schema.push_back(toNullableDAGColumnInfo(field)); - else - output_schema.push_back(field); - } + buildLeftSideJoinSchema(output_schema, children[0]->output_schema, tp); + buildRightSideJoinSchema(output_schema, children[1]->output_schema, tp); } void JoinBinder::fillJoinKeyAndFieldType( @@ -158,6 +196,7 @@ bool JoinBinder::toTiPBExecutor( join->set_join_exec_type(tipb::JoinExecType::TypeHashJoin); join->set_inner_idx(inner_index); join->set_is_null_aware_semi_join(is_null_aware_semi_join); + assert(is_null_eq.empty() || is_null_eq.size() == join_cols.size()); for (const auto & key : join_cols) { @@ -175,6 +214,9 @@ bool JoinBinder::toTiPBExecutor( collator_id); } + for (const auto flag : is_null_eq) + join->add_is_null_eq(flag != 0); + for (const auto & expr : left_conds) { tipb::Expr * cond = join->add_left_conditions(); @@ -187,11 +229,8 @@ bool JoinBinder::toTiPBExecutor( astToPB(children[1]->output_schema, expr, cond, collator_id, context); } - DAGSchema merged_children_schema{children[0]->output_schema}; - merged_children_schema.insert( - merged_children_schema.end(), - children[1]->output_schema.begin(), - children[1]->output_schema.end()); + DAGSchema merged_children_schema + = buildOtherConditionSchema(children[0]->output_schema, children[1]->output_schema, tp); for (const auto & expr : other_conds) { @@ -293,45 +332,6 @@ void JoinBinder::toMPPSubPlan( exchange_map[right_exchange_receiver->name] = std::make_pair(right_exchange_receiver, right_exchange_sender); } -static void buildLeftSideJoinSchema(DAGSchema & schema, const DAGSchema & left_schema, tipb::JoinType tp) -{ - for (const auto & field : left_schema) - { - if (tp == tipb::JoinType::TypeRightOuterJoin && field.second.hasNotNullFlag()) - schema.push_back(toNullableDAGColumnInfo(field)); - else - schema.push_back(field); - } -} - -static void buildRightSideJoinSchema(DAGSchema & schema, const DAGSchema & right_schema, tipb::JoinType tp) -{ - /// Note: for semi join, the right table column is ignored - /// but for (anti) left outer semi join, a 1/0 (uint8) field is pushed back - /// indicating whether right table has matching row(s), see comment in ASTTableJoin::Kind for details. - if (tp == tipb::JoinType::TypeLeftOuterSemiJoin || tp == tipb::JoinType::TypeAntiLeftOuterSemiJoin) - { - tipb::FieldType field_type{}; - field_type.set_tp(TiDB::TypeTiny); - field_type.set_charset("binary"); - field_type.set_collate(TiDB::ITiDBCollator::BINARY); - field_type.set_flag(0); - field_type.set_flen(-1); - field_type.set_decimal(-1); - schema.push_back(std::make_pair("", TiDB::fieldTypeToColumnInfo(field_type))); - } - else if (tp != tipb::JoinType::TypeSemiJoin && tp != tipb::JoinType::TypeAntiSemiJoin) - { - for (const auto & field : right_schema) - { - if (tp == tipb::JoinType::TypeLeftOuterJoin && field.second.hasNotNullFlag()) - schema.push_back(toNullableDAGColumnInfo(field)); - else - schema.push_back(field); - } - } -} - // compileJoin constructs a mocked Join executor node, note that all conditional expression params can be default ExecutorBinderPtr compileJoin( size_t & executor_index, @@ -339,6 +339,7 @@ ExecutorBinderPtr compileJoin( ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, + const std::vector & is_null_eq, const ASTs & left_conds, const ASTs & right_conds, const ASTs & other_conds, @@ -357,6 +358,7 @@ ExecutorBinderPtr compileJoin( output_schema, tp, join_cols, + is_null_eq, left_conds, right_conds, other_conds, @@ -405,6 +407,6 @@ ExecutorBinderPtr compileJoin(size_t & executor_index, ExecutorBinderPtr left, E join_cols.push_back(key); } } - return compileJoin(executor_index, left, right, tp, join_cols); + return compileJoin(executor_index, left, right, tp, join_cols, {}); } } // namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/JoinBinder.h b/dbms/src/Debug/MockExecutor/JoinBinder.h index 336183266a9..b6290a577be 100644 --- a/dbms/src/Debug/MockExecutor/JoinBinder.h +++ b/dbms/src/Debug/MockExecutor/JoinBinder.h @@ -29,6 +29,7 @@ class JoinBinder : public ExecutorBinder const DAGSchema & output_schema_, tipb::JoinType tp_, const ASTs & join_cols_, + const std::vector & is_null_eq_, const ASTs & l_conds, const ASTs & r_conds, const ASTs & o_conds, @@ -39,6 +40,7 @@ class JoinBinder : public ExecutorBinder : ExecutorBinder(index_, "Join_" + std::to_string(index_), output_schema_) , tp(tp_) , join_cols(join_cols_) + , is_null_eq(is_null_eq_) , left_conds(l_conds) , right_conds(r_conds) , other_conds(o_conds) @@ -77,6 +79,7 @@ class JoinBinder : public ExecutorBinder tipb::JoinType tp; const ASTs join_cols{}; + const std::vector is_null_eq{}; const ASTs left_conds{}; const ASTs right_conds{}; const ASTs other_conds{}; @@ -93,6 +96,7 @@ ExecutorBinderPtr compileJoin( ExecutorBinderPtr right, tipb::JoinType tp, const ASTs & join_cols, + const std::vector & is_null_eq = {}, const ASTs & left_conds = {}, const ASTs & right_conds = {}, const ASTs & other_conds = {}, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 8c57f5031e3..e8530dfada4 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -468,7 +468,7 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::JsonReplaceSig, "cast"}, //{tipb::ScalarFuncSig::JsonRemoveSig, "cast"}, //{tipb::ScalarFuncSig::JsonMergeSig, "cast"}, - //{tipb::ScalarFuncSig::JsonObjectSig, "cast"}, + {tipb::ScalarFuncSig::JsonObjectSig, "json_object"}, {tipb::ScalarFuncSig::JsonArraySig, "json_array"}, {tipb::ScalarFuncSig::JsonValidJsonSig, "json_valid_json"}, {tipb::ScalarFuncSig::JsonValidOthersSig, "json_valid_others"}, @@ -847,6 +847,8 @@ String getJoinTypeName(const tipb::JoinType & tp) return "LeftOuterJoin"; case tipb::JoinType::TypeRightOuterJoin: return "RightOuterJoin"; + case tipb::JoinType::TypeFullOuterJoin: + return "FullOuterJoin"; case tipb::JoinType::TypeLeftOuterSemiJoin: return "LeftOuterSemiJoin"; case tipb::JoinType::TypeAntiSemiJoin: diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp index 10f00045eac..45d4747af68 100644 --- a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp @@ -60,6 +60,7 @@ std::pair getJoinKindAndBuildSideIndex( /// 3. for non-cross left/right outer join, there is no problem in this swap. /// 4. for cross left outer join, the build side is always right, needn't and can't swap. /// 5. for cross right outer join, the build side is always left, so it will always swap and change to cross left outer join. + /// 6. for non-cross full outer join, keep full join kind and respect inner_idx as build side. /// note that whatever the build side is, we can't support cross-right-outer join now. static const std::unordered_map< std::pair, @@ -72,6 +73,8 @@ std::pair getJoinKindAndBuildSideIndex( {{tipb::JoinType::TypeLeftOuterJoin, 1}, {ASTTableJoin::Kind::LeftOuter, 1}}, {{tipb::JoinType::TypeRightOuterJoin, 0}, {ASTTableJoin::Kind::LeftOuter, 0}}, {{tipb::JoinType::TypeRightOuterJoin, 1}, {ASTTableJoin::Kind::RightOuter, 1}}, + {{tipb::JoinType::TypeFullOuterJoin, 0}, {ASTTableJoin::Kind::Full, 0}}, + {{tipb::JoinType::TypeFullOuterJoin, 1}, {ASTTableJoin::Kind::Full, 1}}, {{tipb::JoinType::TypeSemiJoin, 0}, {ASTTableJoin::Kind::RightSemi, 0}}, {{tipb::JoinType::TypeSemiJoin, 1}, {ASTTableJoin::Kind::Semi, 1}}, {{tipb::JoinType::TypeAntiSemiJoin, 0}, {ASTTableJoin::Kind::RightAnti, 0}}, @@ -103,6 +106,8 @@ std::pair getJoinKindAndBuildSideIndex( {{tipb::JoinType::TypeAntiLeftOuterSemiJoin, 1}, {ASTTableJoin::Kind::NullAware_LeftOuterAnti, 1}}}; RUNTIME_ASSERT(inner_index == 0 || inner_index == 1); + if (unlikely(tipb_join_type == tipb::JoinType::TypeFullOuterJoin && join_keys_size == 0)) + throw TiFlashException("Cartesian full outer join is not supported yet", Errors::Coprocessor::BadRequest); const auto & join_type_map = [is_null_aware, join_keys_size]() { if (is_null_aware) { @@ -183,6 +188,19 @@ JoinKeyTypes getJoinKeyTypes(const tipb::Join & join) return join_key_types; } +std::vector getJoinKeyNullEqFlags(const tipb::Join & join) +{ + if (unlikely(join.is_null_eq_size() != 0 && join.is_null_eq_size() != join.left_join_keys_size())) + throw TiFlashException( + "size of join.is_null_eq does not match size of join.left_join_keys/right_join_keys", + Errors::Coprocessor::BadRequest); + + std::vector is_null_eq(join.left_join_keys_size(), 0); + for (int i = 0; i < join.is_null_eq_size(); ++i) + is_null_eq[i] = join.is_null_eq(i) ? 1 : 0; + return is_null_eq; +} + TiDB::TiDBCollators getJoinKeyCollators(const tipb::Join & join, const JoinKeyTypes & join_key_types, bool is_test) { TiDB::TiDBCollators collators; @@ -212,7 +230,18 @@ TiFlashJoin::TiFlashJoin(const tipb::Join & join_, bool is_test) // NOLINT(cppco : join(join_) , join_key_types(getJoinKeyTypes(join_)) , join_key_collators(getJoinKeyCollators(join_, join_key_types, is_test)) + , is_null_eq(getJoinKeyNullEqFlags(join_)) { + if (unlikely(join.is_null_aware_semi_join())) + { + for (auto flag : is_null_eq) + { + if (flag != 0) + throw TiFlashException( + "NullEQ join keys are incompatible with null-aware semi join", + Errors::Coprocessor::BadRequest); + } + } std::tie(kind, build_side_index) = getJoinKindAndBuildSideIndex(join); } @@ -295,8 +324,8 @@ NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( column_set_for_origin_columns.emplace(p.name); } }; - append_origin_columns(left_cols, join.join_type() == tipb::JoinType::TypeRightOuterJoin); - append_origin_columns(right_cols, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + append_origin_columns(left_cols, makeLeftJoinSideNullable(join.join_type())); + append_origin_columns(right_cols, makeRightJoinSideNullable(join.join_type())); /// append the columns generated by probe side prepare join actions. /// the new columns are @@ -310,8 +339,8 @@ NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); } }; - bool make_nullable = build_side_index == 1 ? join.join_type() == tipb::JoinType::TypeRightOuterJoin - : join.join_type() == tipb::JoinType::TypeLeftOuterJoin; + bool make_nullable = build_side_index == 1 ? makeLeftJoinSideNullable(join.join_type()) + : makeRightJoinSideNullable(join.join_type()); append_new_columns(probe_prepare_join_actions->getSampleBlock(), make_nullable); return columns_for_other_join_filter; @@ -330,11 +359,11 @@ NamesAndTypes TiFlashJoin::genJoinOutputColumns( } }; - append_output_columns(left_cols, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + append_output_columns(left_cols, makeLeftJoinSideNullable(join.join_type())); if (!isSemiFamily() && !isLeftOuterSemiFamily()) { /// for (left outer) semi join, the columns from right table will be ignored - append_output_columns(right_cols, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + append_output_columns(right_cols, makeRightJoinSideNullable(join.join_type())); } if (!match_helper_name.empty()) @@ -407,6 +436,43 @@ std::tuple prepareJoin( return {chain.getLastActions(), std::move(key_names), std::move(original_key_names), std::move(filter_column_name)}; } +void alignNullEqKeyTypes( + const std::vector & is_null_eq, + const ExpressionActionsPtr & probe_prepare_actions, + Names & probe_key_names, + const ExpressionActionsPtr & build_prepare_actions, + Names & build_key_names) +{ + RUNTIME_CHECK(probe_key_names.size() == build_key_names.size()); + RUNTIME_CHECK(probe_key_names.size() == is_null_eq.size()); + + for (size_t i = 0; i < is_null_eq.size(); ++i) + { + if (is_null_eq[i] == 0) + continue; + + const auto & probe_type = probe_prepare_actions->getSampleBlock().getByName(probe_key_names[i]).type; + const auto & build_type = build_prepare_actions->getSampleBlock().getByName(build_key_names[i]).type; + if (probe_type->equals(*build_type)) + continue; + + RUNTIME_CHECK_MSG( + removeNullable(probe_type)->equals(*removeNullable(build_type)), + "NullEQ key type mismatch after prepareJoin is not a pure nullability mismatch: probe={} build={}", + probe_type->getName(), + build_type->getName()); + + if (!probe_type->isNullable()) + { + probe_prepare_actions->add(ExpressionAction::convertToNullable(probe_key_names[i])); + } + if (!build_type->isNullable()) + { + build_prepare_actions->add(ExpressionAction::convertToNullable(build_key_names[i])); + } + } +} + std::vector TiFlashJoin::genRuntimeFilterList( const Context & context, const NamesAndTypes & source_columns, @@ -447,6 +513,25 @@ std::vector TiFlashJoin::genRuntimeFilterList( return result; } +bool TiFlashJoin::shouldDisableRuntimeFilter( + const ExpressionActionsPtr & build_prepare_actions, + const Names & build_key_names) const +{ + RUNTIME_CHECK(build_prepare_actions != nullptr); + RUNTIME_CHECK(build_key_names.size() == is_null_eq.size()); + + const auto & sample_block = build_prepare_actions->getSampleBlock(); + for (size_t i = 0; i < is_null_eq.size(); ++i) + { + if (is_null_eq[i] == 0) + continue; + + if (sample_block.getByName(build_key_names[i]).type->isNullable()) + return true; + } + return false; +} + NamesAndTypes genDAGExpressionAnalyzerSourceColumns(Block block, const NamesAndTypes & tidb_schema) { /// generate source_columns that is used to compile tipb::Expr, the rule is columns in `tidb_schema` diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h index 4745c6e34ce..5a44476d463 100644 --- a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h @@ -94,9 +94,9 @@ struct JoinNonEqualConditions /// Validate this JoinNonEqualConditions and return error message if any. const char * validate(ASTTableJoin::Kind kind) const { - if unlikely (!left_filter_column.empty() && !isLeftOuterJoin(kind)) + if unlikely (!left_filter_column.empty() && !(isLeftOuterJoin(kind) || kind == ASTTableJoin::Kind::Full)) return "non left join with left conditions"; - if unlikely (!right_filter_column.empty() && !isRightOuterJoin(kind)) + if unlikely (!right_filter_column.empty() && !(isRightOuterJoin(kind) || kind == ASTTableJoin::Kind::Full)) return "non right join with right conditions"; if unlikely ((!other_cond_name.empty() || !other_eq_cond_from_in_name.empty()) && other_cond_expr == nullptr) @@ -119,6 +119,16 @@ struct JoinNonEqualConditions namespace JoinInterpreterHelper { +constexpr bool makeLeftJoinSideNullable(tipb::JoinType join_type) +{ + return join_type == tipb::JoinType::TypeRightOuterJoin || join_type == tipb::JoinType::TypeFullOuterJoin; +} + +constexpr bool makeRightJoinSideNullable(tipb::JoinType join_type) +{ + return join_type == tipb::JoinType::TypeLeftOuterJoin || join_type == tipb::JoinType::TypeFullOuterJoin; +} + struct TiFlashJoin { TiFlashJoin(const tipb::Join & join_, bool is_test); @@ -130,6 +140,7 @@ struct TiFlashJoin JoinKeyTypes join_key_types; TiDB::TiDBCollators join_key_collators; + std::vector is_null_eq; /// (cartesian) (anti) left outer semi join. bool isLeftOuterSemiFamily() const @@ -207,6 +218,9 @@ struct TiFlashJoin const NamesAndTypes & source_columns, const std::unordered_map & key_names_map, const LoggerPtr & log); + + bool shouldDisableRuntimeFilter(const ExpressionActionsPtr & build_prepare_actions, const Names & build_key_names) + const; }; /// @join_prepare_expr_actions: generates join key columns and join filter column @@ -220,6 +234,13 @@ std::tuple prepareJoin( const JoinKeyTypes & join_key_types, const google::protobuf::RepeatedPtrField & filters); +void alignNullEqKeyTypes( + const std::vector & is_null_eq, + const ExpressionActionsPtr & probe_prepare_actions, + Names & probe_key_names, + const ExpressionActionsPtr & build_prepare_actions, + Names & build_key_names); + /// generate source_columns that is used to compile tipb::Expr, the rule is columns in `tidb_schema` /// must be the first part of the source_columns NamesAndTypes genDAGExpressionAnalyzerSourceColumns(Block block, const NamesAndTypes & tidb_schema); diff --git a/dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp b/dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp index 6c9ec8d1ed0..ddfbd1f89aa 100644 --- a/dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp +++ b/dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -172,9 +173,9 @@ bool collectForJoin(std::vector & output_field_types, const tip // collect output_field_types for join self for (auto & field_type : children_output_field_types[0]) { - if (executor.join().join_type() == tipb::JoinType::TypeRightOuterJoin) + if (JoinInterpreterHelper::makeLeftJoinSideNullable(executor.join().join_type())) { - /// the type of left column for right join is always nullable + /// the type of left column for right/full join is always nullable auto updated_field_type = field_type; updated_field_type.set_flag( static_cast(updated_field_type.flag()) & (~static_cast(TiDB::ColumnFlagNotNull))); @@ -210,9 +211,9 @@ bool collectForJoin(std::vector & output_field_types, const tip /// for semi/anti semi join, the right table column is ignored for (auto & field_type : children_output_field_types[1]) { - if (executor.join().join_type() == tipb::JoinType::TypeLeftOuterJoin) + if (JoinInterpreterHelper::makeRightJoinSideNullable(executor.join().join_type())) { - /// the type of right column for left join is always nullable + /// the type of right column for left/full join is always nullable auto updated_field_type = field_type; updated_field_type.set_flag( updated_field_type.flag() & (~static_cast(TiDB::ColumnFlagNotNull))); diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp index 4cfa8db2369..ffe2248a21c 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include +#include +#include #include #include @@ -26,6 +28,38 @@ class JoinKindAndBuildIndexTestRunner : public testing::Test { }; +namespace +{ +tipb::Expr makeJoinKeyWithFieldType() +{ + tipb::Expr expr; + expr.mutable_field_type()->set_tp(TiDB::TypeLong); + return expr; +} + +tipb::Join makeFullOuterJoinForSchemaTest(size_t inner_index) +{ + tipb::Join join; + join.set_join_type(tipb::JoinType::TypeFullOuterJoin); + join.set_inner_idx(inner_index); + *join.add_left_join_keys() = makeJoinKeyWithFieldType(); + *join.add_right_join_keys() = makeJoinKeyWithFieldType(); + return join; +} + +tipb::Join makeNullAwareJoinWithNullEq() +{ + tipb::Join join; + join.set_join_type(tipb::JoinType::TypeAntiSemiJoin); + join.set_inner_idx(1); + join.set_is_null_aware_semi_join(true); + *join.add_left_join_keys() = makeJoinKeyWithFieldType(); + *join.add_right_join_keys() = makeJoinKeyWithFieldType(); + join.add_is_null_eq(true); + return join; +} +} // namespace + bool invalidParams(tipb::JoinType tipb_join_type, size_t inner_index, bool is_null_aware, size_t join_keys_size) { try @@ -39,6 +73,33 @@ bool invalidParams(tipb::JoinType tipb_join_type, size_t inner_index, bool is_nu } } +String getErrorMessage(tipb::JoinType tipb_join_type, size_t inner_index, bool is_null_aware, size_t join_keys_size) +{ + try + { + JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb_join_type, inner_index, is_null_aware, join_keys_size); + return ""; + } + catch (Exception & e) + { + return e.message(); + } +} + +String getTiFlashJoinErrorMessage(const tipb::Join & join) +{ + try + { + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, false); + static_cast(tiflash_join); + return ""; + } + catch (Exception & e) + { + return e.message(); + } +} + TEST(JoinKindAndBuildIndexTestRunner, TestNullAwareJoins) { auto result = JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb::JoinType::TypeAntiSemiJoin, 1, true, 1); @@ -59,6 +120,178 @@ TEST(JoinKindAndBuildIndexTestRunner, TestNullAwareJoins) ASSERT_TRUE(invalidParams(tipb::JoinType::TypeAntiLeftOuterSemiJoin, 0, true, 1)); } +TEST(JoinKindAndBuildIndexTestRunner, TestNullAwareJoinRejectsNullEqKeys) +{ + auto error_message = getTiFlashJoinErrorMessage(makeNullAwareJoinWithNullEq()); + ASSERT_FALSE(error_message.empty()); + ASSERT_NE(error_message.find("NullEQ"), String::npos); +} + +TEST(JoinKindAndBuildIndexTestRunner, TestNullEqAlignsMixedNullabilityKeySchema) +{ + try + { + auto int_type = std::make_shared(); + auto nullable_int_type = makeNullable(int_type); + auto context = TiFlashTestEnv::getContext(); + + ColumnWithTypeAndName probe_column{nullptr, int_type, "probe_k"}; + ColumnWithTypeAndName build_column{nullptr, nullable_int_type, "build_k"}; + + tipb::Join join; + join.set_join_type(tipb::JoinType::TypeInnerJoin); + join.set_inner_idx(1); + *join.add_left_join_keys() = columnToTiPBExpr(probe_column, 0); + *join.add_right_join_keys() = columnToTiPBExpr(build_column, 0); + join.add_is_null_eq(true); + + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, true); + + NamesAndTypes probe_source_columns{{probe_column.name, probe_column.type}}; + NamesAndTypes build_source_columns{{build_column.name, build_column.type}}; + + auto [probe_prepare_actions, probe_key_names, original_probe_key_names, probe_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + probe_source_columns, + tiflash_join.getProbeJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getProbeConditions()); + auto [build_prepare_actions, build_key_names, original_build_key_names, build_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + build_source_columns, + tiflash_join.getBuildJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getBuildConditions()); + + ASSERT_FALSE(probe_prepare_actions->getSampleBlock().getByName(probe_key_names[0]).type->isNullable()); + ASSERT_TRUE(build_prepare_actions->getSampleBlock().getByName(build_key_names[0]).type->isNullable()); + + JoinInterpreterHelper::alignNullEqKeyTypes( + tiflash_join.is_null_eq, + probe_prepare_actions, + probe_key_names, + build_prepare_actions, + build_key_names); + + ASSERT_TRUE(probe_prepare_actions->getSampleBlock().getByName(probe_key_names[0]).type->isNullable()); + ASSERT_TRUE(build_prepare_actions->getSampleBlock().getByName(build_key_names[0]).type->isNullable()); + ASSERT_TRUE(probe_prepare_actions->getSampleBlock() + .getByName(probe_key_names[0]) + .type->equals(*build_prepare_actions->getSampleBlock().getByName(build_key_names[0]).type)); + } + catch (Exception & e) + { + FAIL() << e.message(); + } +} + +TEST(JoinKindAndBuildIndexTestRunner, TestNullableNullEqDisablesRuntimeFilter) +{ + try + { + auto int_type = std::make_shared(); + auto nullable_int_type = makeNullable(int_type); + auto context = TiFlashTestEnv::getContext(); + + ColumnWithTypeAndName probe_column{nullptr, int_type, "probe_k"}; + ColumnWithTypeAndName build_column{nullptr, nullable_int_type, "build_k"}; + + tipb::Join join; + join.set_join_type(tipb::JoinType::TypeInnerJoin); + join.set_inner_idx(1); + *join.add_left_join_keys() = columnToTiPBExpr(probe_column, 0); + *join.add_right_join_keys() = columnToTiPBExpr(build_column, 0); + join.add_is_null_eq(true); + + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, true); + + NamesAndTypes probe_source_columns{{probe_column.name, probe_column.type}}; + NamesAndTypes build_source_columns{{build_column.name, build_column.type}}; + + auto [probe_prepare_actions, probe_key_names, original_probe_key_names, probe_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + probe_source_columns, + tiflash_join.getProbeJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getProbeConditions()); + auto [build_prepare_actions, build_key_names, original_build_key_names, build_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + build_source_columns, + tiflash_join.getBuildJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getBuildConditions()); + + JoinInterpreterHelper::alignNullEqKeyTypes( + tiflash_join.is_null_eq, + probe_prepare_actions, + probe_key_names, + build_prepare_actions, + build_key_names); + + ASSERT_TRUE(tiflash_join.shouldDisableRuntimeFilter(build_prepare_actions, build_key_names)); + } + catch (Exception & e) + { + FAIL() << e.message(); + } +} + +TEST(JoinKindAndBuildIndexTestRunner, TestNonNullableNullEqKeepsRuntimeFilterEnabled) +{ + try + { + auto int_type = std::make_shared(); + auto context = TiFlashTestEnv::getContext(); + + ColumnWithTypeAndName probe_column{nullptr, int_type, "probe_k"}; + ColumnWithTypeAndName build_column{nullptr, int_type, "build_k"}; + + tipb::Join join; + join.set_join_type(tipb::JoinType::TypeInnerJoin); + join.set_inner_idx(1); + *join.add_left_join_keys() = columnToTiPBExpr(probe_column, 0); + *join.add_right_join_keys() = columnToTiPBExpr(build_column, 0); + join.add_is_null_eq(true); + + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, true); + + NamesAndTypes probe_source_columns{{probe_column.name, probe_column.type}}; + NamesAndTypes build_source_columns{{build_column.name, build_column.type}}; + + auto [probe_prepare_actions, probe_key_names, original_probe_key_names, probe_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + probe_source_columns, + tiflash_join.getProbeJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getProbeConditions()); + auto [build_prepare_actions, build_key_names, original_build_key_names, build_filter_column_name] + = JoinInterpreterHelper::prepareJoin( + *context, + build_source_columns, + tiflash_join.getBuildJoinKeys(), + tiflash_join.join_key_types, + tiflash_join.getBuildConditions()); + + JoinInterpreterHelper::alignNullEqKeyTypes( + tiflash_join.is_null_eq, + probe_prepare_actions, + probe_key_names, + build_prepare_actions, + build_key_names); + + ASSERT_FALSE(tiflash_join.shouldDisableRuntimeFilter(build_prepare_actions, build_key_names)); + } + catch (Exception & e) + { + FAIL() << e.message(); + } +} + TEST(JoinKindAndBuildIndexTestRunner, TestCrossJoins) { /// Cross Inner Join, both sides supported @@ -97,6 +330,11 @@ TEST(JoinKindAndBuildIndexTestRunner, TestCrossJoins) ASSERT_TRUE(invalidParams(tipb::JoinType::TypeLeftOuterSemiJoin, 0, false, 0)); ASSERT_TRUE(invalidParams(tipb::JoinType::TypeAntiLeftOuterSemiJoin, 0, false, 0)); + + /// Cross FullOuterJoin is out of scope in this round and should fail with a clear message. + auto error_message = getErrorMessage(tipb::JoinType::TypeFullOuterJoin, 0, false, 0); + ASSERT_FALSE(error_message.empty()); + ASSERT_NE(error_message.find("Cartesian full outer join"), String::npos); } TEST(JoinKindAndBuildIndexTestRunner, TestEqualJoins) @@ -119,6 +357,12 @@ TEST(JoinKindAndBuildIndexTestRunner, TestEqualJoins) result = JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb::JoinType::TypeRightOuterJoin, 1, false, 1); ASSERT_TRUE(result.first == ASTTableJoin::Kind::RightOuter && result.second == 1); + /// FullOuterJoin, keep full join kind and respect inner_idx as build side. + result = JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb::JoinType::TypeFullOuterJoin, 0, false, 1); + ASSERT_TRUE(result.first == ASTTableJoin::Kind::Full && result.second == 0); + result = JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb::JoinType::TypeFullOuterJoin, 1, false, 1); + ASSERT_TRUE(result.first == ASTTableJoin::Kind::Full && result.second == 1); + /// Semi/Anti result = JoinInterpreterHelper::getJoinKindAndBuildSideIndex(tipb::JoinType::TypeSemiJoin, 1, false, 1); ASSERT_TRUE(result.first == ASTTableJoin::Kind::Semi && result.second == 1); @@ -140,5 +384,63 @@ TEST(JoinKindAndBuildIndexTestRunner, TestEqualJoins) ASSERT_TRUE(invalidParams(tipb::JoinType::TypeAntiLeftOuterSemiJoin, 0, false, 1)); } +TEST(JoinKindAndBuildIndexTestRunner, TestFullJoinOutputColumnsAreNullable) +{ + auto join = makeFullOuterJoinForSchemaTest(1); + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, false); + + auto int_type = std::make_shared(); + NamesAndTypes left_cols{{"l.a", int_type}, {"l.b", int_type}}; + NamesAndTypes right_cols{{"r.a", int_type}, {"r.b", int_type}}; + + auto join_output_columns = tiflash_join.genJoinOutputColumns(left_cols, right_cols, ""); + ASSERT_EQ(join_output_columns.size(), 4); + for (const auto & column : join_output_columns) + ASSERT_TRUE(column.type->isNullable()) << column.name; +} + +TEST(JoinKindAndBuildIndexTestRunner, TestFullJoinOtherConditionColumnsAreNullable) +{ + auto int_type = std::make_shared(); + NamesAndTypes left_cols{{"l.a", int_type}, {"l.b", int_type}}; + NamesAndTypes right_cols{{"r.a", int_type}, {"r.b", int_type}}; + + for (size_t inner_index : {size_t{0}, size_t{1}}) + { + auto join = makeFullOuterJoinForSchemaTest(inner_index); + JoinInterpreterHelper::TiFlashJoin tiflash_join(join, false); + + NamesAndTypes probe_prepare_columns = inner_index == 1 + ? NamesAndTypes{{"l.a", int_type}, {"l.b", int_type}, {"probe_extra", int_type}} + : NamesAndTypes{{"r.a", int_type}, {"r.b", int_type}, {"probe_extra", int_type}}; + auto probe_prepare_join_actions = std::make_shared(probe_prepare_columns); + + auto columns_for_other_join_filter + = tiflash_join.genColumnsForOtherJoinFilter(left_cols, right_cols, probe_prepare_join_actions); + ASSERT_EQ(columns_for_other_join_filter.size(), 5); + ASSERT_EQ(columns_for_other_join_filter.back().name, "probe_extra"); + for (const auto & column : columns_for_other_join_filter) + ASSERT_TRUE(column.type->isNullable()) << column.name; + } +} + +TEST(JoinKindAndBuildIndexTestRunner, TestFullJoinAllowsLeftAndRightConditions) +{ + JoinNonEqualConditions full_conditions; + full_conditions.left_filter_column = "left_cond"; + full_conditions.right_filter_column = "right_cond"; + ASSERT_EQ(full_conditions.validate(ASTTableJoin::Kind::Full), nullptr); + + JoinNonEqualConditions left_only_conditions; + left_only_conditions.left_filter_column = "left_cond"; + ASSERT_EQ(left_only_conditions.validate(ASTTableJoin::Kind::LeftOuter), nullptr); + ASSERT_STREQ(left_only_conditions.validate(ASTTableJoin::Kind::Inner), "non left join with left conditions"); + + JoinNonEqualConditions right_only_conditions; + right_only_conditions.right_filter_column = "right_cond"; + ASSERT_EQ(right_only_conditions.validate(ASTTableJoin::Kind::RightOuter), nullptr); + ASSERT_STREQ(right_only_conditions.validate(ASTTableJoin::Kind::Inner), "non right join with right conditions"); +} + } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp index ba0c43c0edf..60758042af1 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp @@ -123,6 +123,13 @@ PhysicalPlanNodePtr PhysicalJoin::build( original_build_key_names, join_non_equal_conditions); + JoinInterpreterHelper::alignNullEqKeyTypes( + tiflash_join.is_null_eq, + probe_side_prepare_actions, + probe_key_names, + build_side_prepare_actions, + build_key_names); + const Settings & settings = context.getSettingsRef(); size_t max_bytes_before_external_join = settings.max_bytes_before_external_join; auto join_req_id = fmt::format("{}_{}", log->identifier(), executor_id); @@ -152,20 +159,29 @@ PhysicalPlanNodePtr PhysicalJoin::build( right_input_header, join_non_equal_conditions.other_cond_expr != nullptr); - assert(build_key_names.size() == original_build_key_names.size()); - std::unordered_map build_key_names_map; - for (size_t i = 0; i < original_build_key_names.size(); ++i) + std::vector runtime_filter_list; + if (tiflash_join.shouldDisableRuntimeFilter(build_side_prepare_actions, build_key_names)) { - build_key_names_map[original_build_key_names[i]] = build_key_names[i]; + LOG_INFO(log, "Disable runtime filter because a nullable NullEQ build key is present"); + } + else + { + assert(build_key_names.size() == original_build_key_names.size()); + std::unordered_map build_key_names_map; + for (size_t i = 0; i < original_build_key_names.size(); ++i) + { + build_key_names_map[original_build_key_names[i]] = build_key_names[i]; + } + runtime_filter_list + = tiflash_join.genRuntimeFilterList(context, build_source_columns, build_key_names_map, log); } - auto runtime_filter_list - = tiflash_join.genRuntimeFilterList(context, build_source_columns, build_key_names_map, log); LOG_DEBUG(log, "before register runtime filter list, list size:{}", runtime_filter_list.size()); context.getDAGContext()->runtime_filter_mgr.registerRuntimeFilterList(runtime_filter_list); JoinPtr join_ptr = std::make_shared( probe_key_names, build_key_names, + tiflash_join.is_null_eq, tiflash_join.kind, join_req_id, fine_grained_shuffle.stream_count, diff --git a/dbms/src/Flash/tests/gtest_join_executor.cpp b/dbms/src/Flash/tests/gtest_join_executor.cpp index f58a6062e60..145c3f59f84 100644 --- a/dbms/src/Flash/tests/gtest_join_executor.cpp +++ b/dbms/src/Flash/tests/gtest_join_executor.cpp @@ -4261,6 +4261,210 @@ try } CATCH +TEST_F(JoinExecutorTestRunner, FullOuterJoinWithOtherCondition) +try +{ + using tipb::JoinType; + + const std::vector> + test_cases + = {{ + {toNullableVec("a", {1}), toNullableVec("c", {10})}, + {toNullableVec("a", {1}), toNullableVec("c", {5})}, + {toNullableVec("a", {1, {}}), + toNullableVec("c", {10, {}}), + toNullableVec("a", {{}, 1}), + toNullableVec("c", {{}, 5})}, + "key matched but other condition failed for all rows", + }, + { + {toNullableVec("a", {2}), toNullableVec("c", {20})}, + {toNullableVec("a", {1}), toNullableVec("c", {5})}, + {toNullableVec("a", {2, {}}), + toNullableVec("c", {20, {}}), + toNullableVec("a", {{}, 1}), + toNullableVec("c", {{}, 5})}, + "key not matched", + }, + { + {toNullableVec("a", {{}}), toNullableVec("c", {10})}, + {toNullableVec("a", {1}), toNullableVec("c", {5})}, + {toNullableVec("a", {{}, {}}), + toNullableVec("c", {10, {}}), + toNullableVec("a", {{}, 1}), + toNullableVec("c", {{}, 5})}, + "probe-side null key should still be emitted as unmatched row", + }, + { + {toNullableVec("a", {1}), toNullableVec("c", {3})}, + {toNullableVec("a", {1, 1}), toNullableVec("c", {2, 4})}, + {toNullableVec("a", {1, {}}), + toNullableVec("c", {3, {}}), + toNullableVec("a", {1, 1}), + toNullableVec("c", {4, 2})}, + "only rows that pass other condition should be marked used", + }}; + + for (const auto & [left, right, res, test_case_name] : test_cases) + { + SCOPED_TRACE(test_case_name); + context.addMockTable( + "full_outer_other_condition", + "t", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + left); + context.addMockTable( + "full_outer_other_condition", + "s", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + right); + + auto request = context.scan("full_outer_other_condition", "t") + .join( + context.scan("full_outer_other_condition", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + false, + 1) + .build(context); + executeAndAssertColumnsEqual(request, res); + + auto request_column_prune = context.scan("full_outer_other_condition", "t") + .join( + context.scan("full_outer_other_condition", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {lt(col("t.c"), col("s.c"))}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(res), executeStreams(request_column_prune, 2)); + } +} +CATCH + +TEST_F(JoinExecutorTestRunner, FullOuterJoinWithoutOtherConditionAndNullKey) +try +{ + using tipb::JoinType; + + context.addMockTable( + "full_outer_no_other_condition", + "t", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + {toNullableVec("a", {1, 2, {}, 4}), toNullableVec("c", {10, 20, 30, 40})}); + context.addMockTable( + "full_outer_no_other_condition", + "s", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + {toNullableVec("a", {1, 3, {}}), toNullableVec("c", {100, 300, 400})}); + + auto request = context.scan("full_outer_no_other_condition", "t") + .join( + context.scan("full_outer_no_other_condition", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {}, + {}, + 0, + false, + 1) + .build(context); + + const ColumnsWithTypeAndName expected = { + toNullableVec({1, 2, {}, 4, {}, {}}), + toNullableVec({10, 20, 30, 40, {}, {}}), + toNullableVec({1, {}, {}, {}, {}, 3}), + toNullableVec({100, {}, {}, {}, 400, 300}), + }; + executeAndAssertColumnsEqual(request, expected); + + auto request_column_prune = context.scan("full_outer_no_other_condition", "t") + .join( + context.scan("full_outer_no_other_condition", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(expected), executeStreams(request_column_prune, 2)); +} +CATCH + +TEST_F(JoinExecutorTestRunner, FullOuterJoinWithLeftAndRightConditions) +try +{ + using tipb::JoinType; + + context.addMockTable( + "full_outer_with_side_conditions", + "t", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + {toNullableVec("a", {1, 2}), toNullableVec("c", {5, 20})}); + context.addMockTable( + "full_outer_with_side_conditions", + "s", + {{"a", TiDB::TP::TypeLong}, {"c", TiDB::TP::TypeLong}}, + {toNullableVec("a", {1, 3}), toNullableVec("c", {200, 50})}); + + auto request = context.scan("full_outer_with_side_conditions", "t") + .join( + context.scan("full_outer_with_side_conditions", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {gt(col("t.c"), lit(Field(static_cast(10))))}, + {gt(col("s.c"), lit(Field(static_cast(100))))}, + {}, + {}, + 0, + false, + 1) + .build(context); + + const ColumnsWithTypeAndName expected = { + toNullableVec({1, 2, {}, {}}), + toNullableVec({5, 20, {}, {}}), + toNullableVec({{}, {}, 3, 1}), + toNullableVec({{}, {}, 50, 200}), + }; + executeAndAssertColumnsEqual(request, expected); + + auto request_column_prune = context.scan("full_outer_with_side_conditions", "t") + .join( + context.scan("full_outer_with_side_conditions", "s"), + JoinType::TypeFullOuterJoin, + {col("a")}, + {gt(col("t.c"), lit(Field(static_cast(10))))}, + {gt(col("s.c"), lit(Field(static_cast(100))))}, + {}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(expected), executeStreams(request_column_prune, 2)); +} +CATCH + #undef WRAP_FOR_JOIN_TEST_BEGIN #undef WRAP_FOR_JOIN_TEST_END diff --git a/dbms/src/Flash/tests/gtest_spill_join.cpp b/dbms/src/Flash/tests/gtest_spill_join.cpp index 75097500d91..5a008cfa99f 100644 --- a/dbms/src/Flash/tests/gtest_spill_join.cpp +++ b/dbms/src/Flash/tests/gtest_spill_join.cpp @@ -15,6 +15,7 @@ #include #include +#include namespace DB { @@ -30,6 +31,50 @@ class SpillJoinTestRunner : public DB::tests::JoinTestRunner } }; +namespace +{ +constexpr auto * null_eq_join_db = "null_eq_join_test"; +constexpr auto * null_eq_left_table = "left_nullable_table"; +constexpr auto * null_eq_right_table = "right_nullable_table"; +constexpr auto * null_eq_right_exchange = "right_nullable_exchange"; +constexpr size_t null_eq_rows = 4096; +constexpr size_t null_eq_fgs_stream_count = 5; + +ColumnsWithTypeAndName makeNullEqJoinColumns(size_t rows, Int32 value_shift) +{ + std::vector keys; + std::vector values; + keys.reserve(rows); + values.reserve(rows); + for (size_t i = 0; i < rows; ++i) + { + if (i % 64 == 0) + keys.emplace_back(std::nullopt); + else + keys.emplace_back(static_cast(i)); + values.emplace_back(static_cast((i + value_shift) % 11)); + } + return {toNullableVec("k", keys), toNullableVec("v", values)}; +} + +void addNullEqJoinSources(MockDAGRequestContext & context) +{ + MockColumnInfoVec column_infos{{"k", TiDB::TP::TypeLong}, {"v", TiDB::TP::TypeLong}}; + MockColumnInfoVec partition_column_infos{{"k", TiDB::TP::TypeLong}}; + auto left_columns = makeNullEqJoinColumns(null_eq_rows, 1); + auto right_columns = makeNullEqJoinColumns(null_eq_rows, 5); + + context.addMockTable(null_eq_join_db, null_eq_left_table, column_infos, left_columns, 10); + context.addMockTable(null_eq_join_db, null_eq_right_table, column_infos, right_columns, 10); + context.addExchangeReceiver( + null_eq_right_exchange, + column_infos, + right_columns, + null_eq_fgs_stream_count, + partition_column_infos); +} +} // namespace + #define WRAP_FOR_SPILL_TEST_BEGIN \ std::vector pipeline_bools{false, true}; \ for (auto enable_pipeline : pipeline_bools) \ @@ -575,6 +620,172 @@ try } CATCH +TEST_F(SpillJoinTestRunner, FullOuterJoinWithOtherConditionSpill) +try +{ + UInt64 max_block_size = 800; + size_t original_max_streams = 20; + UInt64 max_bytes_before_external_join = 20000; + String left_table_name = "left_table_10_concurrency"; + String right_table_name = "right_table_10_concurrency"; + + WRAP_FOR_SPILL_TEST_BEGIN + auto request = context.scan("outer_join_test", left_table_name) + .join( + context.scan("outer_join_test", right_table_name), + tipb::JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {lt(col(left_table_name + ".b"), col(right_table_name + ".b"))}, + {}, + 0, + false, + 1) + .project( + {fmt::format("{}.a", left_table_name), + fmt::format("{}.b", left_table_name), + fmt::format("{}.a", right_table_name), + fmt::format("{}.b", right_table_name)}) + .build(context); + auto request_column_prune = context.scan("outer_join_test", left_table_name) + .join( + context.scan("outer_join_test", right_table_name), + tipb::JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {lt(col(left_table_name + ".b"), col(right_table_name + ".b"))}, + {}, + 0, + false, + 1) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + + context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); + context.context->setSetting("max_bytes_before_external_join", Field(static_cast(0))); + auto ref_columns = executeStreams(request, original_max_streams); + + context.context->setSetting( + "max_bytes_before_external_join", + Field(static_cast(max_bytes_before_external_join))); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(ref_columns), executeStreams(request_column_prune, 2)); + WRAP_FOR_SPILL_TEST_END +} +CATCH + +TEST_F(SpillJoinTestRunner, FullOuterJoinWithOtherConditionNullEqSpill) +try +{ + constexpr UInt64 max_block_size = 800; + constexpr size_t original_max_streams = 20; + constexpr UInt64 max_bytes_before_external_join = 20000; + addNullEqJoinSources(context); + + auto left_key = fmt::format("{}.k", null_eq_left_table); + auto left_value = fmt::format("{}.v", null_eq_left_table); + auto right_key = fmt::format("{}.k", null_eq_right_table); + auto right_value = fmt::format("{}.v", null_eq_right_table); + + WRAP_FOR_SPILL_TEST_BEGIN + auto request = context.scan(null_eq_join_db, null_eq_left_table) + .join( + context.scan(null_eq_join_db, null_eq_right_table), + tipb::JoinType::TypeFullOuterJoin, + {col("k")}, + {}, + {}, + {lt(col(left_value), col(right_value))}, + {}, + 0, + false, + 1, + {1}) + .project({left_key, left_value, right_key, right_value}) + .build(context); + auto request_column_prune = context.scan(null_eq_join_db, null_eq_left_table) + .join( + context.scan(null_eq_join_db, null_eq_right_table), + tipb::JoinType::TypeFullOuterJoin, + {col("k")}, + {}, + {}, + {lt(col(left_value), col(right_value))}, + {}, + 0, + false, + 1, + {1}) + .aggregation({Count(lit(static_cast(1)))}, {}) + .build(context); + + context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); + context.context->setSetting("max_bytes_before_external_join", Field(static_cast(0))); + auto ref_columns = executeStreams(request, original_max_streams); + + context.context->setSetting( + "max_bytes_before_external_join", + Field(static_cast(max_bytes_before_external_join))); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); + ASSERT_COLUMNS_EQ_UR(genScalarCountResults(ref_columns), executeStreams(request_column_prune, 2)); + WRAP_FOR_SPILL_TEST_END +} +CATCH + +TEST_F(SpillJoinTestRunner, FineGrainedShuffleNullEqJoin) +try +{ + constexpr size_t original_max_streams = 20; + constexpr size_t original_max_streams_small = 4; + addNullEqJoinSources(context); + + auto left_key = fmt::format("{}.k", null_eq_left_table); + auto left_value = fmt::format("{}.v", null_eq_left_table); + auto right_value = fmt::format("{}.v", null_eq_right_table); + auto exchange_right_value = fmt::format("{}.v", null_eq_right_exchange); + + WRAP_FOR_SPILL_TEST_BEGIN + context.context->setSetting("max_bytes_before_external_join", Field(static_cast(0))); + auto reference = context.scan(null_eq_join_db, null_eq_left_table) + .join( + context.scan(null_eq_join_db, null_eq_right_table), + tipb::JoinType::TypeInnerJoin, + {col("k")}, + {}, + {}, + {}, + {}, + 0, + false, + 1, + {1}) + .project({left_key, left_value, right_value}) + .build(context); + auto ref_columns = executeStreams(reference, original_max_streams); + + auto request = context.scan(null_eq_join_db, null_eq_left_table) + .join( + context.receive(null_eq_right_exchange, null_eq_fgs_stream_count), + tipb::JoinType::TypeInnerJoin, + {col("k")}, + {}, + {}, + {}, + {}, + null_eq_fgs_stream_count, + false, + 1, + {1}) + .project({left_key, left_value, exchange_right_value}) + .build(context); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams_small)); + WRAP_FOR_SPILL_TEST_END +} +CATCH + #undef WRAP_FOR_SPILL_TEST_BEGIN #undef WRAP_FOR_SPILL_TEST_END diff --git a/dbms/src/Functions/FunctionsJson.cpp b/dbms/src/Functions/FunctionsJson.cpp index 15b12b3178c..fb1106e6256 100644 --- a/dbms/src/Functions/FunctionsJson.cpp +++ b/dbms/src/Functions/FunctionsJson.cpp @@ -24,6 +24,7 @@ void registerFunctionsJson(FunctionFactory & factory) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); diff --git a/dbms/src/Functions/FunctionsJson.h b/dbms/src/Functions/FunctionsJson.h index e72e792150b..0054a6c4f40 100644 --- a/dbms/src/Functions/FunctionsJson.h +++ b/dbms/src/Functions/FunctionsJson.h @@ -42,7 +42,9 @@ #include #include +#include #include +#include #include #include @@ -976,6 +978,186 @@ class FunctionJsonArray : public IFunction }; +class FunctionJsonObject : public IFunction +{ +public: + static constexpr auto name = "json_object"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 0; } + + bool isVariadic() const override { return true; } + + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (unlikely(arguments.size() % 2 != 0)) + { + throw Exception( + fmt::format("Incorrect parameter count in the call to native function '{}'", getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + for (const auto arg_idx : ext::range(0, arguments.size())) + { + if (!arguments[arg_idx]->onlyNull()) + { + const auto * arg = removeNullable(arguments[arg_idx]).get(); + if (!arg->isStringOrFixedString()) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument {} of function {}", + arg->getName(), + arg_idx + 1, + getName()); + } + } + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + if (arguments.empty()) + { + // clang-format off + const UInt8 empty_object_json_value[] = { + JsonBinary::TYPE_CODE_OBJECT, // object_type + 0x0, 0x0, 0x0, 0x0, // element_count + 0x8, 0x0, 0x0, 0x0}; // total_size + // clang-format on + auto empty_object_json = ColumnString::create(); + empty_object_json->insertData( + reinterpret_cast(empty_object_json_value), + sizeof(empty_object_json_value) / sizeof(UInt8)); + block.getByPosition(result).column = ColumnConst::create(std::move(empty_object_json), block.rows()); + return; + } + + auto nested_block = createBlockWithNestedColumns(block, arguments); + StringSources sources; + for (auto column_number : arguments) + { + sources.push_back( + block.getByPosition(column_number).column->onlyNull() + ? nullptr + : createDynamicStringSource(*nested_block.getByPosition(column_number).column)); + } + + auto rows = block.rows(); + auto col_to = ColumnString::create(); + auto & data_to = col_to->getChars(); + auto & offsets_to = col_to->getOffsets(); + offsets_to.resize(rows); + + std::vector nullmaps; + nullmaps.reserve(sources.size()); + bool is_input_nullable = false; + for (auto column_number : arguments) + { + const auto & col = block.getByPosition(column_number).column; + if (col->isColumnNullable()) + { + const auto & column_nullable = static_cast(*col); + nullmaps.push_back(&(column_nullable.getNullMapData())); + is_input_nullable = true; + } + else + { + nullmaps.push_back(nullptr); + } + } + + if (is_input_nullable) + doExecuteImpl(sources, rows, data_to, offsets_to, nullmaps); + else + doExecuteImpl(sources, rows, data_to, offsets_to, nullmaps); + + block.getByPosition(result).column = std::move(col_to); + } + +private: + template + static void doExecuteImpl( + StringSources & sources, + size_t rows, + ColumnString::Chars_t & data_to, + ColumnString::Offsets & offsets_to, + const std::vector & nullmaps) + { + const size_t pair_count = sources.size() / 2; + size_t reserve_size = rows * (1 + pair_count * 16); + for (const auto & source : sources) + reserve_size += source ? source->getSizeForReserve() : rows; + JsonBinary::JsonBinaryWriteBuffer write_buffer(data_to, reserve_size); + + std::map key_value_map; + std::vector keys; + std::vector values; + keys.reserve(pair_count); + values.reserve(pair_count); + + for (size_t i = 0; i < rows; ++i) + { + key_value_map.clear(); + for (size_t col = 0; col < sources.size(); col += 2) + { + if constexpr (is_input_nullable) + { + const auto * key_nullmap = nullmaps[col]; + if (!sources[col] || (key_nullmap && (*key_nullmap)[i])) + throw Exception("JSON documents may not contain NULL member names."); + } + + assert(sources[col]); + const auto & key_from = sources[col]->getWhole(); + if (unlikely(key_from.size > std::numeric_limits::max())) + throw Exception("TiDB/TiFlash does not yet support JSON objects with the key length >= 65536"); + String key(reinterpret_cast(key_from.data), key_from.size); + + JsonBinary value(JsonBinary::TYPE_CODE_LITERAL, StringRef(&JsonBinary::LITERAL_NIL, 1)); + if constexpr (is_input_nullable) + { + const auto * value_nullmap = nullmaps[col + 1]; + if (sources[col + 1] && !(value_nullmap && (*value_nullmap)[i])) + { + const auto & data_from = sources[col + 1]->getWhole(); + value = JsonBinary(data_from.data[0], StringRef(&data_from.data[1], data_from.size - 1)); + } + } + else + { + assert(sources[col + 1]); + const auto & data_from = sources[col + 1]->getWhole(); + value = JsonBinary(data_from.data[0], StringRef(&data_from.data[1], data_from.size - 1)); + } + + key_value_map.insert_or_assign(std::move(key), value); + } + + keys.clear(); + values.clear(); + for (const auto & [key, value] : key_value_map) + { + keys.emplace_back(key.data(), key.size()); + values.push_back(value); + } + + JsonBinary::buildBinaryJsonObjectInBuffer(keys, values, write_buffer); + writeChar(0, write_buffer); + offsets_to[i] = write_buffer.count(); + + for (const auto & source : sources) + { + if (source) + source->next(); + } + } + } +}; + + class FunctionCastJsonAsJson : public IFunction { public: diff --git a/dbms/src/Functions/tests/gtest_json_object.cpp b/dbms/src/Functions/tests/gtest_json_object.cpp new file mode 100644 index 00000000000..f8222081e0d --- /dev/null +++ b/dbms/src/Functions/tests/gtest_json_object.cpp @@ -0,0 +1,98 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include + +namespace DB::tests +{ +class TestJsonObject : public DB::tests::FunctionTest +{ +public: + ColumnWithTypeAndName castStringToJson(const ColumnWithTypeAndName & column) + { + assert(removeNullable(column.type)->isString()); + ColumnsWithTypeAndName inputs{column}; + return executeFunction("cast_string_as_json", inputs, nullptr, true); + } + + ColumnWithTypeAndName executeFunctionWithCast( + const ColumnNumbers & argument_column_numbers, + const ColumnsWithTypeAndName & columns) + { + auto json_column = executeFunction("json_object", argument_column_numbers, columns); + tipb::FieldType field_type; + field_type.set_flen(-1); + field_type.set_collate(TiDB::ITiDBCollator::BINARY); + field_type.set_tp(TiDB::TypeString); + return executeCastJsonAsStringFunction(json_column, field_type); + } +}; + +TEST_F(TestJsonObject, TestBasicSemantics) +try +{ + constexpr size_t rows_count = 2; + + { + ColumnsWithTypeAndName inputs{createColumn({"placeholder", "placeholder"})}; + auto res = executeFunctionWithCast({}, inputs); + ASSERT_COLUMN_EQ(createConstColumn>(rows_count, "{}"), res); + } + + { + ColumnsWithTypeAndName inputs{ + createColumn({"b", "b"}), + castStringToJson(createColumn({"1", "1"})), + createColumn({"a", "a"}), + castStringToJson(createColumn>({{}, "\"x\""})), + }; + auto res = executeFunctionWithCast({0, 1, 2, 3}, inputs); + auto expect = createColumn>({R"({"a": null, "b": 1})", R"({"a": "x", "b": 1})"}); + ASSERT_COLUMN_EQ(expect, res); + } + + { + ColumnsWithTypeAndName inputs{ + createConstColumn(rows_count, "dup"), + castStringToJson(createConstColumn(rows_count, "1")), + createConstColumn(rows_count, "dup"), + castStringToJson(createColumn({"2", "3"})), + }; + auto res = executeFunctionWithCast({0, 1, 2, 3}, inputs); + auto expect = createColumn>({R"({"dup": 2})", R"({"dup": 3})"}); + ASSERT_COLUMN_EQ(expect, res); + } +} +CATCH + +TEST_F(TestJsonObject, TestErrors) +try +{ + ASSERT_THROW(executeFunction("json_object", {createColumn({"a"})}), Exception); + + auto value = castStringToJson(createColumn({"1"})); + ASSERT_THROW(executeFunction("json_object", {createColumn>({{}}), value}), Exception); + + String too_long_key(std::numeric_limits::max() + 1, 'a'); + ASSERT_THROW(executeFunction("json_object", {createColumn({too_long_key}), value}), Exception); +} +CATCH + +} // namespace DB::tests diff --git a/dbms/src/Interpreters/CrossJoinProbeHelper.cpp b/dbms/src/Interpreters/CrossJoinProbeHelper.cpp index 73f5881a659..8c464432963 100644 --- a/dbms/src/Interpreters/CrossJoinProbeHelper.cpp +++ b/dbms/src/Interpreters/CrossJoinProbeHelper.cpp @@ -410,7 +410,7 @@ struct CrossJoinAdder } }; -template +template Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_info, const Blocks & right_blocks) { size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); @@ -444,9 +444,9 @@ Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_inf for (; current_row < block_rows; ++current_row) { - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { - if ((*probe_process_info.null_map)[current_row]) + if ((*probe_process_info.row_filter_map)[current_row]) { /// filter out by left_conditions, so just treated as not joined column block_full = CrossJoinAdder::addNotFound( @@ -501,7 +501,7 @@ Block crossProbeBlockDeepCopyRightBlockImpl(ProbeProcessInfo & probe_process_inf return probe_process_info.cross_join_data->result_block_schema.cloneWithColumns(std::move(dst_columns)); } -template +template std::pair crossProbeBlockShallowCopyRightBlockAddNotMatchedRows(ProbeProcessInfo & probe_process_info) { size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); @@ -523,12 +523,12 @@ std::pair crossProbeBlockShallowCopyRightBlockAddNotMatchedRows(Pro .column.get(); } IColumn::Filter::value_type filter_column_value{}; - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { - // todo use column->filter(null_map) to construct the result block in batch + // todo use column->filter(row_filter_map) to construct the result block in batch for (size_t i = 0; i < probe_process_info.block.rows(); ++i) { - if ((*probe_process_info.null_map)[i]) + if ((*probe_process_info.row_filter_map)[i]) { CrossJoinAdder::addNotFound( dst_columns, @@ -565,7 +565,7 @@ std::pair crossProbeBlockShallowCopyRightBlockAddNotMatchedRows(Pro return {probe_process_info.cross_join_data->result_block_schema.cloneWithColumns(std::move(dst_columns)), false}; } -template +template std::pair crossProbeBlockShallowCopyRightBlockImpl( ProbeProcessInfo & probe_process_info, const Blocks & right_blocks) @@ -574,11 +574,11 @@ std::pair crossProbeBlockShallowCopyRightBlockImpl( assert(probe_process_info.offsets_to_replicate != nullptr); size_t num_existing_columns = probe_process_info.cross_join_data->left_column_index_in_left_block.size(); - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { /// skip filtered rows, the filtered rows will be handled at the end of this block while (probe_process_info.start_row < probe_process_info.block.rows() - && (*probe_process_info.null_map)[probe_process_info.start_row]) + && (*probe_process_info.row_filter_map)[probe_process_info.start_row]) { ++probe_process_info.start_row; } @@ -587,7 +587,7 @@ std::pair crossProbeBlockShallowCopyRightBlockImpl( if (probe_process_info.start_row == probe_process_info.block.rows()) { /// current probe block is done, collect un-matched rows - return crossProbeBlockShallowCopyRightBlockAddNotMatchedRows( + return crossProbeBlockShallowCopyRightBlockAddNotMatchedRows( probe_process_info); } assert(probe_process_info.cross_join_data->next_right_block_index < right_blocks.size()); @@ -645,45 +645,55 @@ Block crossProbeBlockDeepCopyRightBlock( { using enum ASTTableJoin::Strictness; using enum ASTTableJoin::Kind; -#define DISPATCH(HAS_NULL_MAP) \ - if (kind == Cross && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_LeftOuter && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuter && strictness == Any) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_Semi && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_Semi && strictness == Any) \ - return crossProbeBlockDeepCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_Anti && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_Anti && strictness == Any) \ - return crossProbeBlockDeepCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_LeftOuterSemi && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterSemi && strictness == Any) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterAnti && strictness == All) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterAnti && strictness == Any) \ - return crossProbeBlockDeepCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else \ +#define DISPATCH(HAS_ROW_FILTER_MAP) \ + if (kind == Cross && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuter && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuter && strictness == Any) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Semi && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Semi && strictness == Any) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Anti && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Anti && strictness == Any) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterSemi && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterSemi && strictness == Any) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterAnti && strictness == All) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterAnti && strictness == Any) \ + return crossProbeBlockDeepCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else \ throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); - if (probe_process_info.null_map) + if (probe_process_info.row_filter_map) { DISPATCH(true) } @@ -702,53 +712,55 @@ std::pair crossProbeBlockShallowCopyRightBlock( { using enum ASTTableJoin::Strictness; using enum ASTTableJoin::Kind; -#define DISPATCH(HAS_NULL_MAP) \ - if (kind == Cross && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl(probe_process_info, right_blocks); \ - else if (kind == Cross_LeftOuter && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuter && strictness == Any) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_Semi && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_Semi && strictness == Any) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_Anti && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_Anti && strictness == Any) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterSemi && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterSemi && strictness == Any) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterAnti && strictness == All) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else if (kind == Cross_LeftOuterAnti && strictness == Any) \ - return crossProbeBlockShallowCopyRightBlockImpl( \ - probe_process_info, \ - right_blocks); \ - else \ +#define DISPATCH(HAS_ROW_FILTER_MAP) \ + if (kind == Cross && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuter && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuter && strictness == Any) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Semi && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Semi && strictness == Any) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Anti && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_Anti && strictness == Any) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterSemi && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterSemi && strictness == Any) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterAnti && strictness == All) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else if (kind == Cross_LeftOuterAnti && strictness == Any) \ + return crossProbeBlockShallowCopyRightBlockImpl( \ + probe_process_info, \ + right_blocks); \ + else \ throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); - if (probe_process_info.null_map) + if (probe_process_info.row_filter_map) { DISPATCH(true) } diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 873772014f8..8a19af16b3b 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -51,22 +51,36 @@ extern const int TYPE_MISMATCH; namespace { -ColumnRawPtrs getKeyColumns(const Names & key_names, const Block & block) +ColumnRawPtrs getKeyColumns(const Names & key_names, const Block & block, const std::vector & is_null_eq = {}) { size_t keys_size = key_names.size(); + RUNTIME_CHECK(is_null_eq.empty() || is_null_eq.size() == keys_size); ColumnRawPtrs key_columns(keys_size); for (size_t i = 0; i < keys_size; ++i) { key_columns[i] = block.getByName(key_names[i]).column.get(); - /// We will join only keys, where all components are not NULL. - if (key_columns[i]->isColumnNullable()) + /// Ordinary '=' keys join only nested values where all components are not NULL. + /// NullEQ keys must keep their nullable wrapper so nullness can participate in key comparison. + if (key_columns[i]->isColumnNullable() && (is_null_eq.empty() || is_null_eq[i] == 0)) key_columns[i] = &static_cast(*key_columns[i]).getNestedColumn(); } return key_columns; } + +bool hasNullableNullEqKey(const Names & key_names, const Block & block, const std::vector & is_null_eq) +{ + RUNTIME_CHECK(key_names.size() == is_null_eq.size()); + for (size_t i = 0; i < key_names.size(); ++i) + { + if (is_null_eq[i] != 0 && block.getByName(key_names[i]).type->isNullable()) + return true; + } + return false; +} + size_t getRestoreJoinBuildConcurrency( size_t total_partitions, size_t spilled_partitions, @@ -96,6 +110,31 @@ size_t getRestoreJoinBuildConcurrency( } } +String formatNullEqFlags(const std::vector & flags) +{ + String result; + result.reserve(flags.size() * 2 + 2); + result += "["; + for (size_t i = 0; i < flags.size(); ++i) + { + if (i != 0) + result += ","; + result += flags[i] == 0 ? "0" : "1"; + } + result += "]"; + return result; +} + +bool hasNullEqKey(const std::vector & flags) +{ + for (auto flag : flags) + { + if (flag != 0) + return true; + } + return false; +} + } // namespace using PointerHelper = PointerTypeColumnHelper; @@ -111,6 +150,7 @@ const size_t MAX_RESTORE_ROUND_IN_GTEST = 2; Join::Join( const Names & key_names_left_, const Names & key_names_right_, + const std::vector & is_null_eq_, ASTTableJoin::Kind kind_, const String & req_id, size_t fine_grained_shuffle_count_, @@ -140,6 +180,7 @@ Join::Join( , may_probe_side_expanded_after_join(mayProbeSideExpandedAfterJoin(kind)) , key_names_left(key_names_left_) , key_names_right(key_names_right_) + , is_null_eq(is_null_eq_) , build_concurrency(0) , active_build_threads(0) , probe_concurrency(0) @@ -202,9 +243,11 @@ Join::Join( LOG_DEBUG( log, - "FineGrainedShuffle flag {}, stream count {}", + "FineGrainedShuffle flag {}, stream count {}, has_null_eq_key {}, is_null_eq {}", enable_fine_grained_shuffle, - fine_grained_shuffle_count); + fine_grained_shuffle_count, + hasNullEqKey(is_null_eq), + formatNullEqFlags(is_null_eq)); } void Join::meetError(const String & error_message_) @@ -357,6 +400,7 @@ std::shared_ptr Join::createRestoreJoin(size_t max_bytes_before_external_j auto ret = std::make_shared( key_names_left, key_names_right, + is_null_eq, kind, join_req_id, /// restore join never enable fine grained shuffle @@ -395,7 +439,16 @@ void Join::initBuild(const Block & sample_block, size_t build_concurrency_) if (unlikely(initialized)) throw Exception("Logical error: Join has been initialized", ErrorCodes::LOGICAL_ERROR); initialized = true; - join_map_method = chooseJoinMapMethod(getKeyColumns(key_names_right, sample_block), key_sizes, collators); + join_map_method = chooseJoinMapMethod( + getKeyColumns(key_names_right, sample_block, is_null_eq), + key_sizes, + collators, + is_null_eq); + if (hasNullableNullEqKey(key_names_right, sample_block, is_null_eq)) + { + if (join_map_method == JoinMapMethod::serialized) + LOG_DEBUG(log, "Use serialized join map method because nullable NullEQ keys do not fit packed fixed keys"); + } build_sample_block = sample_block; setBuildConcurrencyAndInitJoinPartition(build_concurrency_); hash_join_spill_context->init(build_concurrency); @@ -675,13 +728,12 @@ void Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) } } - /// We will insert to the map only keys, where all components are not NULL. - ColumnPtr null_map_holder; - ConstNullMapPtr null_map{}; - extractNestedColumnsAndNullMap(key_columns, null_map_holder, null_map); - /// Reuse null_map to record the filtered rows, the rows contains NULL or does not - /// match the join filter will not insert to the maps - recordFilteredRows(block, non_equal_conditions.right_filter_column, null_map_holder, null_map); + /// Build a unified row filter map: ordinary '=' key NULLs and side-condition failures skip insertion, + /// while NullEQ key NULLs remain eligible for matching. + ColumnPtr row_filter_map_holder; + ConstNullMapPtr row_filter_map{}; + extractJoinKeyColumnsAndFilterNullMap(key_columns, is_null_eq, row_filter_map_holder, row_filter_map); + recordFilteredRows(block, non_equal_conditions.right_filter_column, row_filter_map_holder, row_filter_map); size_t size = stored_block->columns(); @@ -716,7 +768,7 @@ void Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, getBuildConcurrency(), enable_fine_grained_shuffle, @@ -965,7 +1017,7 @@ void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, I mergeNullAndFilterResult(block, filter, non_equal_conditions.other_eq_cond_from_in_name, isAntiJoin(kind)); assert(block_rows == filter.size()); - if (isInnerJoin(kind) || isNecessaryKindToUseRowFlaggedHashMap(kind)) + if (isInnerJoin(kind) || (isNecessaryKindToUseRowFlaggedHashMap(kind) && kind != ASTTableJoin::Kind::Full)) { erase_useless_column(block); /// inner | rightSemi | rightAnti | rightOuter join, just use other_filter_column to filter result @@ -974,6 +1026,17 @@ void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, I return; } + PointerHelper::ArrayType * full_join_mapped_entries = nullptr; + if (kind == ASTTableJoin::Kind::Full) + { + RUNTIME_CHECK(!flag_mapped_entry_helper_name.empty()); + auto & mapped_column = block.getByName(flag_mapped_entry_helper_name).column; + auto mutable_mapped_column = (*std::move(mapped_column)).mutate(); + auto & ptr_col = static_cast(*mutable_mapped_column); + full_join_mapped_entries = &ptr_col.getData(); + mapped_column = std::move(mutable_mapped_column); + } + bool is_semi_family = isSemiFamily(kind) || isLeftOuterSemiFamily(kind); for (size_t i = 0, prev_offset = 0; i < offsets_to_replicate->size(); ++i) { @@ -996,8 +1059,12 @@ void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, I if (prev_offset < current_offset) { /// for outer join, at least one row must be kept - if (isLeftOuterJoin(kind) && !has_row_kept) + if ((isLeftOuterJoin(kind) || kind == ASTTableJoin::Kind::Full) && !has_row_kept) + { row_filter[prev_offset] = 1; + if (full_join_mapped_entries != nullptr) + (*full_join_mapped_entries)[prev_offset] = 0; + } if (isAntiJoin(kind)) { if (has_row_kept && !(*anti_filter)[i]) @@ -1014,9 +1081,9 @@ void Join::handleOtherConditions(Block & block, IColumn::Filter * anti_filter, I prev_offset = current_offset; } erase_useless_column(block); - if (isLeftOuterJoin(kind)) + if (isLeftOuterJoin(kind) || kind == ASTTableJoin::Kind::Full) { - /// for left join, convert right column to null if not joined + /// for left/full join, convert right column to null if not joined applyNullToNotMatchedRows(block, right_sample_block, *filter_column); for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1); @@ -1268,7 +1335,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui probe_process_info.hash_join_data->key_columns, key_sizes, added_columns, - probe_process_info.null_map, + probe_process_info.row_filter_map, current_offset, offsets_to_replicate, right_indexes, @@ -1331,6 +1398,8 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui for (size_t i = 0; i < block.rows(); ++i) { auto ptr_value = container[i]; + if (ptr_value == 0) + continue; auto * current = reinterpret_cast(ptr_value); current->setUsed(); } @@ -1340,7 +1409,7 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info, const JoinBui // Return build table header for right semi/anti join block = right_sample_block; } - else if (kind == ASTTableJoin::Kind::RightOuter) + else if (kind == ASTTableJoin::Kind::RightOuter || kind == ASTTableJoin::Kind::Full) { block.erase(flag_mapped_entry_helper_name); } @@ -1378,6 +1447,7 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const restore_config.restore_round}; probe_process_info.prepareForHashProbe( key_names_left, + is_null_eq, non_equal_conditions.left_filter_column, kind, strictness, @@ -1556,8 +1626,8 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in null_rows[i] = partitions[i]->getRowsNotInsertedToMap(); NALeftSideInfo left_side_info( - probe_process_info.null_map, - probe_process_info.null_aware_join_data->filter_map, + probe_process_info.null_aware_join_data->key_null_map, + probe_process_info.row_filter_map, probe_process_info.null_aware_join_data->all_key_null_map); NARightSideInfo right_side_info( right_has_all_key_null_row.load(std::memory_order_relaxed), @@ -1714,6 +1784,7 @@ Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const probe_process_info.prepareForHashProbe( key_names_left, + is_null_eq, non_equal_conditions.left_filter_column, kind, strictness, diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 68a4edb6e65..9a1c7b72d57 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -152,9 +152,8 @@ using OneTimeNotifyFuturePtr = std::shared_ptr; * * How Nullable keys are processed: * - * NULLs never join to anything, even to each other. - * During building of map, we just skip keys with NULL value of any component. - * During joining, we simply treat rows with any NULLs in key as non joined. + * For ordinary '=' keys, rows with NULL in any key component are filtered before build/probe. + * For NullEQ keys, NULL is allowed to participate in key comparison. * * Default values for outer joins (LEFT, RIGHT, FULL): * @@ -168,6 +167,7 @@ class Join Join( const Names & key_names_left_, const Names & key_names_right_, + const std::vector & is_null_eq_, ASTTableJoin::Kind kind_, const String & req_id, size_t fine_grained_shuffle_count_, @@ -251,8 +251,10 @@ class Join size_t getTotalBuildInputRows() const { return total_input_build_rows; } ASTTableJoin::Kind getKind() const { return kind; } + JoinMapMethod getJoinMapMethod() const { return join_map_method; } const Names & getLeftJoinKeys() const { return key_names_left; } + const std::vector & getNullEqFlags() const { return is_null_eq; } void setInitActiveBuildThreads() { @@ -350,6 +352,8 @@ class Join const Names key_names_left; /// Names of key columns (columns for equi-JOIN) in "right" table (in the order they appear in USING clause). const Names key_names_right; + /// Per join-key-pair null-safe-equal flags, aligned with key_names_left/key_names_right. + const std::vector is_null_eq; mutable std::mutex build_probe_mutex; diff --git a/dbms/src/Interpreters/JoinHashMap.cpp b/dbms/src/Interpreters/JoinHashMap.cpp index 153b562a0f0..51f448093be 100644 --- a/dbms/src/Interpreters/JoinHashMap.cpp +++ b/dbms/src/Interpreters/JoinHashMap.cpp @@ -14,10 +14,14 @@ #include #include +#include #include #include +#include #include +#include + namespace DB { namespace @@ -33,31 +37,66 @@ bool canAsColumnString(const IColumn * column) JoinMapMethod chooseJoinMapMethod( const ColumnRawPtrs & key_columns, Sizes & key_sizes, - const TiDB::TiDBCollators & collators) + const TiDB::TiDBCollators & collators, + const std::vector & is_null_eq) { const size_t keys_size = key_columns.size(); + RUNTIME_CHECK(is_null_eq.empty() || is_null_eq.size() == keys_size); if (keys_size == 0) return JoinMapMethod::CROSS; + ColumnRawPtrs nested_key_columns; + nested_key_columns.reserve(keys_size); + bool has_nullable_null_eq_key = false; + for (size_t j = 0; j < keys_size; ++j) + { + const auto * key_column = key_columns[j]; + if (const auto * nullable_column = typeid_cast(key_column)) + { + nested_key_columns.push_back(&nullable_column->getNestedColumn()); + has_nullable_null_eq_key = has_nullable_null_eq_key || (!is_null_eq.empty() && is_null_eq[j] != 0); + } + else + { + nested_key_columns.push_back(key_column); + } + } + bool all_fixed = true; size_t keys_bytes = 0; key_sizes.resize(keys_size); for (size_t j = 0; j < keys_size; ++j) { - if (!key_columns[j]->isFixedAndContiguous()) + if (!nested_key_columns[j]->isFixedAndContiguous()) { all_fixed = false; break; } - key_sizes[j] = key_columns[j]->sizeOfValueIfFixed(); + key_sizes[j] = nested_key_columns[j]->sizeOfValueIfFixed(); keys_bytes += key_sizes[j]; } + if (has_nullable_null_eq_key) + { + if (all_fixed) + { + if (keys_bytes > (std::numeric_limits::max() - std::tuple_size>::value)) + throw Exception("Join: keys sizes overflow", ErrorCodes::LOGICAL_ERROR); + + if (std::tuple_size>::value + keys_bytes <= sizeof(UInt128)) + return JoinMapMethod::nullable_keys128; + if (std::tuple_size>::value + keys_bytes <= sizeof(UInt256)) + return JoinMapMethod::nullable_keys256; + } + + return JoinMapMethod::serialized; + } + /// If there is one numeric key that fits in 64 bits - if (keys_size == 1 && key_columns[0]->isNumeric()) + if (keys_size == 1 && nested_key_columns[0]->isNumeric()) { - size_t size_of_field = key_columns[0]->sizeOfValueIfFixed(); + size_t size_of_field = nested_key_columns[0]->sizeOfValueIfFixed(); if (size_of_field == 1) return JoinMapMethod::key8; if (size_of_field == 2) @@ -80,7 +119,7 @@ JoinMapMethod chooseJoinMapMethod( return JoinMapMethod::keys256; /// If there is single string key, use hash table of it's values. - if (keys_size == 1 && canAsColumnString(key_columns[0])) + if (keys_size == 1 && canAsColumnString(nested_key_columns[0])) { if (collators.empty() || !collators[0]) return JoinMapMethod::key_strbin; @@ -108,7 +147,7 @@ JoinMapMethod chooseJoinMapMethod( } } - if (keys_size == 1 && typeid_cast(key_columns[0])) + if (keys_size == 1 && typeid_cast(nested_key_columns[0])) return JoinMapMethod::key_fixed_string; /// Otherwise, use serialized values as the key. diff --git a/dbms/src/Interpreters/JoinHashMap.h b/dbms/src/Interpreters/JoinHashMap.h index d7a37355362..b5f0a975766 100644 --- a/dbms/src/Interpreters/JoinHashMap.h +++ b/dbms/src/Interpreters/JoinHashMap.h @@ -143,6 +143,8 @@ struct WithUsedFlag : Base M(key_fixed_string) \ M(keys128) \ M(keys256) \ + M(nullable_keys128) \ + M(nullable_keys256) \ M(serialized) enum class JoinMapMethod @@ -171,6 +173,8 @@ struct ConcurrentMapsTemplate using key_fixed_stringType = ConcurrentHashMapWithSavedHash; using keys128Type = ConcurrentHashMap>; using keys256Type = ConcurrentHashMap>; + using nullable_keys128Type = ConcurrentHashMap>; + using nullable_keys256Type = ConcurrentHashMap>; using serializedType = ConcurrentHashMapWithSavedHash; std::unique_ptr key8; @@ -183,6 +187,8 @@ struct ConcurrentMapsTemplate std::unique_ptr key_fixed_string; std::unique_ptr keys128; std::unique_ptr keys256; + std::unique_ptr nullable_keys128; + std::unique_ptr nullable_keys256; std::unique_ptr serialized; // TODO: add more cases like Aggregator }; @@ -201,6 +207,8 @@ struct MapsTemplate using key_fixed_stringType = HashMapWithSavedHash; using keys128Type = HashMap>; using keys256Type = HashMap>; + using nullable_keys128Type = HashMap>; + using nullable_keys256Type = HashMap>; using serializedType = HashMapWithSavedHash; std::unique_ptr key8; @@ -213,6 +221,8 @@ struct MapsTemplate std::unique_ptr key_fixed_string; std::unique_ptr keys128; std::unique_ptr keys256; + std::unique_ptr nullable_keys128; + std::unique_ptr nullable_keys256; std::unique_ptr serialized; // TODO: add more cases like Aggregator }; @@ -230,6 +240,8 @@ struct MapsAny using key_fixed_stringType = HashSetWithSavedHash; using keys128Type = HashSet>; using keys256Type = HashSet>; + using nullable_keys128Type = HashSet>; + using nullable_keys256Type = HashSet>; using serializedType = HashSetWithSavedHash; std::unique_ptr key8; @@ -242,6 +254,8 @@ struct MapsAny std::unique_ptr key_fixed_string; std::unique_ptr keys128; std::unique_ptr keys256; + std::unique_ptr nullable_keys128; + std::unique_ptr nullable_keys256; std::unique_ptr serialized; // TODO: add more cases like Aggregator }; @@ -257,5 +271,6 @@ using MapsAllFullWithRowFlag = MapsTemplate; // With fla JoinMapMethod chooseJoinMapMethod( const ColumnRawPtrs & key_columns, Sizes & key_sizes, - const TiDB::TiDBCollators & collators); + const TiDB::TiDBCollators & collators, + const std::vector & is_null_eq = {}); } // namespace DB diff --git a/dbms/src/Interpreters/JoinPartition.cpp b/dbms/src/Interpreters/JoinPartition.cpp index fc88a95a6a6..7d86a95bcd3 100644 --- a/dbms/src/Interpreters/JoinPartition.cpp +++ b/dbms/src/Interpreters/JoinPartition.cpp @@ -440,6 +440,16 @@ struct KeyGetterForTypeImpl using Type = ColumnsHashing::HashMethodKeysFixed; }; template +struct KeyGetterForTypeImpl +{ + using Type = ColumnsHashing::HashMethodKeysFixed; +}; +template +struct KeyGetterForTypeImpl +{ + using Type = ColumnsHashing::HashMethodKeysFixed; +}; +template struct KeyGetterForTypeImpl { using Type = ColumnsHashing::HashMethodSerialized; @@ -526,7 +536,7 @@ template < ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, - bool has_null_map, + bool has_row_filter_map, bool need_record_not_insert_rows> void NO_INLINE insertBlockIntoMapTypeCase( JoinPartition & join_partition, @@ -535,7 +545,7 @@ void NO_INLINE insertBlockIntoMapTypeCase( const Sizes & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, RowsNotInsertToMap * rows_not_inserted_to_map, size_t probe_cache_column_threshold) { @@ -549,13 +559,14 @@ void NO_INLINE insertBlockIntoMapTypeCase( bool null_need_materialize = isNullAwareSemiFamily(join_partition.getJoinKind()); for (size_t i = 0; i < rows; ++i) { - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { - if ((*null_map)[i]) + if ((*row_filter_map)[i]) { if constexpr (need_record_not_insert_rows) { - /// for right/full out join or null-aware semi join, need to insert into rows_not_inserted_to_map + /// For right/full outer join or null-aware semi join, rows filtered before hash-map insertion + /// still need to be preserved in rows_not_inserted_to_map. rows_not_inserted_to_map->insertRow(stored_block, i, null_need_materialize, pool); } continue; @@ -578,7 +589,7 @@ template < ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, - bool has_null_map, + bool has_row_filter_map, bool need_record_not_insert_rows> void NO_INLINE insertBlockIntoMapsTypeCase( JoinPartitions & join_partitions, @@ -587,7 +598,7 @@ void NO_INLINE insertBlockIntoMapsTypeCase( const Sizes & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, size_t stream_index, RowsNotInsertToMap * rows_not_inserted_to_map, size_t probe_cache_column_threshold) @@ -608,7 +619,7 @@ void NO_INLINE insertBlockIntoMapsTypeCase( /// 2. hash value is calculated twice, maybe we can refine the code to cache the hash value /// 3. extra memory to store the segment index info std::vector> segment_index_info; - if constexpr (has_null_map && need_record_not_insert_rows) + if constexpr (has_row_filter_map && need_record_not_insert_rows) { segment_index_info.resize(segment_size + 1); } @@ -623,9 +634,9 @@ void NO_INLINE insertBlockIntoMapsTypeCase( } for (size_t i = 0; i < rows; ++i) { - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { - if ((*null_map)[i]) + if ((*row_filter_map)[i]) { if constexpr (need_record_not_insert_rows) segment_index_info.back().push_back(i); @@ -734,7 +745,7 @@ void insertBlockIntoMapsImplType( const Sizes & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, size_t stream_index, size_t insert_concurrency, bool enable_fine_grained_shuffle, @@ -746,7 +757,7 @@ void insertBlockIntoMapsImplType( if (enable_join_spill) { /// case 1, join with spill support, the partition level lock is acquired in `Join::insertFromBlock` - if (null_map) + if (row_filter_map) { if (rows_not_inserted_to_map) insertBlockIntoMapTypeCase( @@ -756,7 +767,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, rows_not_inserted_to_map, probe_cache_column_threshold); else @@ -767,7 +778,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -780,7 +791,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -789,7 +800,7 @@ void insertBlockIntoMapsImplType( else if (enable_fine_grained_shuffle) { /// case 2, join with fine_grained_shuffle, no need to acquire any lock - if (null_map) + if (row_filter_map) { if (rows_not_inserted_to_map) insertBlockIntoMapTypeCase( @@ -799,7 +810,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, rows_not_inserted_to_map, probe_cache_column_threshold); else @@ -810,7 +821,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -823,7 +834,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -831,7 +842,7 @@ void insertBlockIntoMapsImplType( else if (insert_concurrency > 1) { /// case 3, normal join with concurrency > 1, will acquire lock in `insertBlockIntoMapsTypeCase` - if (null_map) + if (row_filter_map) { if (rows_not_inserted_to_map) insertBlockIntoMapsTypeCase( @@ -841,7 +852,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, rows_not_inserted_to_map, probe_cache_column_threshold); @@ -853,7 +864,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, nullptr, probe_cache_column_threshold); @@ -867,7 +878,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, nullptr, probe_cache_column_threshold); @@ -877,7 +888,7 @@ void insertBlockIntoMapsImplType( { /// case 4, normal join with concurrency == 1, no need to acquire any lock RUNTIME_CHECK(stream_index == 0); - if (null_map) + if (row_filter_map) { if (rows_not_inserted_to_map) insertBlockIntoMapTypeCase( @@ -887,7 +898,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, rows_not_inserted_to_map, probe_cache_column_threshold); else @@ -898,7 +909,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -911,7 +922,7 @@ void insertBlockIntoMapsImplType( key_sizes, collators, stored_block, - null_map, + row_filter_map, nullptr, probe_cache_column_threshold); } @@ -926,7 +937,7 @@ void insertBlockIntoMapsImpl( const Sizes & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, size_t stream_index, size_t insert_concurrency, bool enable_fine_grained_shuffle, @@ -952,7 +963,7 @@ void insertBlockIntoMapsImpl( key_sizes, \ collators, \ stored_block, \ - null_map, \ + row_filter_map, \ stream_index, \ insert_concurrency, \ enable_fine_grained_shuffle, \ @@ -1000,7 +1011,7 @@ void JoinPartition::insertBlockIntoMaps( const std::vector & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr & null_map, + ConstNullMapPtr & row_filter_map, size_t stream_index, size_t insert_concurrency, bool enable_fine_grained_shuffle, @@ -1020,7 +1031,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1034,7 +1045,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1051,7 +1062,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1065,7 +1076,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1082,7 +1093,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1099,7 +1110,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1113,7 +1124,7 @@ void JoinPartition::insertBlockIntoMaps( key_sizes, collators, stored_block, - null_map, + row_filter_map, stream_index, insert_concurrency, enable_fine_grained_shuffle, @@ -1452,6 +1463,28 @@ struct RowFlaggedHashMapAdder (*offsets)[i] = current_offset; return false; } + + static bool addNotFoundForFull( + size_t num_columns_to_add, + MutableColumns & added_columns, + size_t i, + IColumn::Offset & current_offset, + IColumn::Offsets * offsets, + ProbeProcessInfo & probe_process_info) + { + assert(num_columns_to_add + 1 == added_columns.size()); + if (current_offset && current_offset + 1 > probe_process_info.max_block_size) + return true; + + ++current_offset; + (*offsets)[i] = current_offset; + for (size_t j = 0; j < num_columns_to_add; ++j) + added_columns[j]->insertDefault(); + + auto & actual_ptr_col = static_cast(*added_columns[num_columns_to_add]); + actual_ptr_col.getData().push_back(0); + return false; + } }; template < @@ -1459,7 +1492,7 @@ template < ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, - bool has_null_map, + bool has_row_filter_map, bool row_flagged_map> void NO_INLINE probeBlockImplTypeCase( const JoinPartitions & join_partitions, @@ -1467,7 +1500,7 @@ void NO_INLINE probeBlockImplTypeCase( const ColumnRawPtrs & key_columns, const Sizes & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, @@ -1510,11 +1543,25 @@ void NO_INLINE probeBlockImplTypeCase( bool block_full = false; for (i = probe_process_info.start_row; i < rows; ++i) { - if (has_null_map && (*null_map)[i]) + if (has_row_filter_map && (*row_filter_map)[i]) { if constexpr (row_flagged_map) { - block_full = RowFlaggedHashMapAdder::addNotFound(i, current_offset, offsets_to_replicate.get()); + if constexpr (KIND == ASTTableJoin::Kind::Full) + { + block_full = RowFlaggedHashMapAdder::addNotFoundForFull( + num_columns_to_add, + added_columns, + i, + current_offset, + offsets_to_replicate.get(), + probe_process_info); + } + else + { + block_full + = RowFlaggedHashMapAdder::addNotFound(i, current_offset, offsets_to_replicate.get()); + } } /// RightSemi/RightAnti without other conditions, just ignore not matched probe rows else if constexpr (KIND != ASTTableJoin::Kind::RightSemi && KIND != ASTTableJoin::Kind::RightAnti) @@ -1615,8 +1662,21 @@ void NO_INLINE probeBlockImplTypeCase( { if constexpr (row_flagged_map) { - block_full - = RowFlaggedHashMapAdder::addNotFound(i, current_offset, offsets_to_replicate.get()); + if constexpr (KIND == ASTTableJoin::Kind::Full) + { + block_full = RowFlaggedHashMapAdder::addNotFoundForFull( + num_columns_to_add, + added_columns, + i, + current_offset, + offsets_to_replicate.get(), + probe_process_info); + } + else + { + block_full + = RowFlaggedHashMapAdder::addNotFound(i, current_offset, offsets_to_replicate.get()); + } } /// RightSemi/RightAnti without other conditions, just ignore not matched probe rows else if constexpr (KIND != ASTTableJoin::Kind::RightSemi && KIND != ASTTableJoin::Kind::RightAnti) @@ -1654,7 +1714,7 @@ void probeBlockImplType( const ColumnRawPtrs & key_columns, const Sizes & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, @@ -1662,22 +1722,22 @@ void probeBlockImplType( const JoinBuildInfo & join_build_info, ProbeProcessInfo & probe_process_info) { -#define CALL(has_null_map) \ - probeBlockImplTypeCase( \ - join_partitions, \ - rows, \ - key_columns, \ - key_sizes, \ - added_columns, \ - null_map, \ - current_offset, \ - offsets_to_replicate, \ - right_indexes, \ - collators, \ - join_build_info, \ +#define CALL(has_row_filter_map) \ + probeBlockImplTypeCase( \ + join_partitions, \ + rows, \ + key_columns, \ + key_sizes, \ + added_columns, \ + row_filter_map, \ + current_offset, \ + offsets_to_replicate, \ + right_indexes, \ + collators, \ + join_build_info, \ probe_process_info); - if (null_map) + if (row_filter_map) { CALL(true); } @@ -1693,8 +1753,8 @@ template < ASTTableJoin::Strictness STRICTNESS, typename KeyGetter, typename Map, - bool has_null_map, - bool has_filter_map> + bool has_key_null_map, + bool has_row_filter_map> std::pair>, std::list *>> NO_INLINE probeBlockNullAwareSemiInternal( const JoinPartitions & join_partitions, @@ -1724,9 +1784,9 @@ probeBlockNullAwareSemiInternal( /// the result if it's not left outer semi join. for (size_t i = 0; i < rows; ++i) { - if constexpr (has_filter_map) + if constexpr (has_row_filter_map) { - if ((*left_side_info.filter_map)[i]) + if ((*left_side_info.row_filter_map)[i]) { /// Filter out by left_conditions so the result set is empty. res.emplace_back(i, NASemiJoinStep::DONE, nullptr); @@ -1742,9 +1802,9 @@ probeBlockNullAwareSemiInternal( res.back().template setResult(); continue; } - if constexpr (has_null_map) + if constexpr (has_key_null_map) { - if ((*left_side_info.null_map)[i]) + if ((*left_side_info.key_null_map)[i]) { /// some key is null if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any) @@ -1852,19 +1912,19 @@ std::pair>, std::list( \ - join_partitions, \ - rows, \ - key_columns, \ - key_sizes, \ - collators, \ - left_side_info, \ +#define CALL(has_key_null_map, has_row_filter_map) \ + return probeBlockNullAwareSemiInternal( \ + join_partitions, \ + rows, \ + key_columns, \ + key_sizes, \ + collators, \ + left_side_info, \ right_side_info); - if (left_side_info.null_map) + if (left_side_info.key_null_map) { - if (left_side_info.filter_map) + if (left_side_info.row_filter_map) { CALL(true, true); } @@ -1875,7 +1935,7 @@ std::pair>, std::list + bool has_row_filter_map> std::pair>, std::list *>> NO_INLINE probeBlockSemiInternal( const JoinPartitions & join_partitions, @@ -1941,9 +2001,10 @@ probeBlockSemiInternal( const auto & build_hash_data = probe_process_info.hash_join_data->hash_data->getData(); for (size_t i = 0; i < rows; ++i) { - if constexpr (has_null_map) + if constexpr (has_row_filter_map) { - /// If key columns have null map, it means these key columns do not come from IN. + /// row_filter_map means these rows should not enter regular hash probing. + /// For semi-family joins, this covers ordinary '=' key NULLs and side-condition failures. /// For example: /// SQL: select * from t1 where t1.a not in (select t2.a from t2 where t1.b = t2.b) /// t1.a or t2.a can be null. @@ -1951,7 +2012,7 @@ probeBlockSemiInternal( /// and t1.a = t2.a as other condition from IN. /// SQL: select * from t1 where t1.a not in (select t2.a from t2), t1.a or t2.a can be null. /// If this SQL does not have t1.b = t2.b, null-aware anti semi join will be used. - if ((*probe_process_info.null_map)[i]) + if ((*probe_process_info.row_filter_map)[i]) { if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any) { @@ -2052,16 +2113,16 @@ std::pair>, std::list( \ - join_partitions, \ - rows, \ - key_sizes, \ - collators, \ - join_build_info, \ +#define CALL(has_row_filter_map) \ + return probeBlockSemiInternal( \ + join_partitions, \ + rows, \ + key_sizes, \ + collators, \ + join_build_info, \ probe_process_info); - if (probe_process_info.null_map) + if (probe_process_info.row_filter_map) { CALL(true); } @@ -2080,7 +2141,7 @@ void JoinPartition::probeBlock( const ColumnRawPtrs & key_columns, const std::vector & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, @@ -2109,7 +2170,7 @@ void JoinPartition::probeBlock( key_columns, \ key_sizes, \ added_columns, \ - null_map, \ + row_filter_map, \ current_offset, \ offsets_to_replicate, \ right_indexes, \ @@ -2121,8 +2182,10 @@ void JoinPartition::probeBlock( CALL(Inner, All, MapsAll, false) else if (kind == LeftOuter && strictness == All) CALL(LeftOuter, All, MapsAll, false) - else if (kind == Full && strictness == All) + else if (kind == Full && strictness == All && !use_row_flagged_map) CALL(LeftOuter, All, MapsAllFull, false) + else if (kind == Full && strictness == All && use_row_flagged_map) + CALL(Full, All, MapsAllFullWithRowFlag, true) else if (kind == RightOuter && strictness == All && !use_row_flagged_map) CALL(Inner, All, MapsAllFull, false) else if (kind == RightOuter && strictness == All && use_row_flagged_map) @@ -2147,7 +2210,7 @@ void JoinPartition::probeBlockImpl( const ColumnRawPtrs & key_columns, const std::vector & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, @@ -2172,7 +2235,7 @@ void JoinPartition::probeBlockImpl( key_columns, \ key_sizes, \ added_columns, \ - null_map, \ + row_filter_map, \ current_offset, \ offsets_to_replicate, \ right_indexes, \ diff --git a/dbms/src/Interpreters/JoinPartition.h b/dbms/src/Interpreters/JoinPartition.h index 899604565d1..9deb3208bb7 100644 --- a/dbms/src/Interpreters/JoinPartition.h +++ b/dbms/src/Interpreters/JoinPartition.h @@ -127,7 +127,7 @@ class JoinPartition return rows_not_inserted_to_map.get(); } return nullptr; - }; + } Blocks trySpillProbePartition() { std::unique_lock lock(partition_mutex); @@ -176,7 +176,7 @@ class JoinPartition const std::vector & key_sizes, const TiDB::TiDBCollators & collators, Block * stored_block, - ConstNullMapPtr & null_map, + ConstNullMapPtr & row_filter_map, size_t stream_index, size_t insert_concurrency, bool enable_fine_grained_shuffle, @@ -190,7 +190,7 @@ class JoinPartition const ColumnRawPtrs & key_columns, const std::vector & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, @@ -205,7 +205,7 @@ class JoinPartition const ColumnRawPtrs & key_columns, const std::vector & key_sizes, MutableColumns & added_columns, - ConstNullMapPtr null_map, + ConstNullMapPtr row_filter_map, IColumn::Offset & current_offset, std::unique_ptr & offsets_to_replicate, const std::vector & right_indexes, diff --git a/dbms/src/Interpreters/JoinUtils.cpp b/dbms/src/Interpreters/JoinUtils.cpp index 70309b7b907..16859c2892c 100644 --- a/dbms/src/Interpreters/JoinUtils.cpp +++ b/dbms/src/Interpreters/JoinUtils.cpp @@ -40,6 +40,43 @@ ColumnRawPtrs extractAndMaterializeKeyColumns( return key_columns; } +void extractJoinKeyColumnsAndFilterNullMap( + ColumnRawPtrs & key_columns, + const std::vector & is_null_eq, + ColumnPtr & null_map_holder, + ConstNullMapPtr & null_map) +{ + RUNTIME_CHECK(key_columns.size() == is_null_eq.size()); + + for (size_t i = 0; i < key_columns.size(); ++i) + { + auto & column = key_columns[i]; + if (!column->isColumnNullable() || is_null_eq[i] != 0) + continue; + + const auto & column_nullable = static_cast(*column); + column = &column_nullable.getNestedColumn(); + + if (!null_map_holder) + { + null_map_holder = column_nullable.getNullMapColumnPtr(); + } + else + { + MutableColumnPtr mutable_null_map_holder = (*std::move(null_map_holder)).mutate(); + + PaddedPODArray & mutable_null_map = static_cast(*mutable_null_map_holder).getData(); + const PaddedPODArray & other_null_map = column_nullable.getNullMapData(); + for (size_t row = 0, size = mutable_null_map.size(); row < size; ++row) + mutable_null_map[row] |= other_null_map[row]; + + null_map_holder = std::move(mutable_null_map_holder); + } + } + + null_map = null_map_holder ? &static_cast(*null_map_holder).getData() : nullptr; +} + void recordFilteredRows( const Block & block, const String & filter_column, diff --git a/dbms/src/Interpreters/JoinUtils.h b/dbms/src/Interpreters/JoinUtils.h index 898741fc35b..bae4ad8fb8a 100644 --- a/dbms/src/Interpreters/JoinUtils.h +++ b/dbms/src/Interpreters/JoinUtils.h @@ -88,7 +88,7 @@ inline bool needScanHashMapAfterProbe(ASTTableJoin::Kind kind) inline bool isNecessaryKindToUseRowFlaggedHashMap(ASTTableJoin::Kind kind) { - return isRightSemiFamily(kind) || kind == ASTTableJoin::Kind::RightOuter; + return isRightSemiFamily(kind) || kind == ASTTableJoin::Kind::RightOuter || kind == ASTTableJoin::Kind::Full; } inline bool useRowFlaggedHashMap(ASTTableJoin::Kind kind, bool has_other_condition) @@ -142,6 +142,11 @@ ColumnRawPtrs extractAndMaterializeKeyColumns( const Block & block, Columns & materialized_columns, const Strings & key_columns_names); +void extractJoinKeyColumnsAndFilterNullMap( + ColumnRawPtrs & key_columns, + const std::vector & is_null_eq, + ColumnPtr & null_map_holder, + ConstNullMapPtr & null_map); void recordFilteredRows( const Block & block, const String & filter_column, diff --git a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h index 63d5a170513..1223346e42d 100644 --- a/dbms/src/Interpreters/NullAwareSemiJoinHelper.h +++ b/dbms/src/Interpreters/NullAwareSemiJoinHelper.h @@ -63,15 +63,18 @@ struct NARightSideInfo struct NALeftSideInfo { NALeftSideInfo( - const ConstNullMapPtr & null_map_, - const ConstNullMapPtr & filter_map_, + const ConstNullMapPtr & key_null_map_, + const ConstNullMapPtr & row_filter_map_, const ConstNullMapPtr & all_key_null_map_) - : null_map(null_map_) - , filter_map(filter_map_) + : key_null_map(key_null_map_) + , row_filter_map(row_filter_map_) , all_key_null_map(all_key_null_map_) {} - const ConstNullMapPtr & null_map; - const ConstNullMapPtr & filter_map; + /// Rows whose null-aware join key contains at least one NULL. + const ConstNullMapPtr & key_null_map; + /// Rows filtered out by side conditions before null-aware probing. + const ConstNullMapPtr & row_filter_map; + /// Rows whose null-aware join keys are all NULL. const ConstNullMapPtr & all_key_null_map; }; diff --git a/dbms/src/Interpreters/ProbeProcessInfo.cpp b/dbms/src/Interpreters/ProbeProcessInfo.cpp index 42f28f04901..f9853bab92a 100644 --- a/dbms/src/Interpreters/ProbeProcessInfo.cpp +++ b/dbms/src/Interpreters/ProbeProcessInfo.cpp @@ -32,8 +32,8 @@ void ProbeProcessInfo::resetBlock(Block && block_, size_t partition_index_) // min_result_block_size is used to avoid generating too many small block, use 50% of the block size as the default value min_result_block_size = std::max(1, (std::min(block.rows(), max_block_size) + 1) / 2); prepare_for_probe_done = false; - null_map = nullptr; - null_map_holder = nullptr; + row_filter_map = nullptr; + row_filter_map_holder = nullptr; filter.reset(); offsets_to_replicate.reset(); if (hash_join_data) @@ -46,6 +46,7 @@ void ProbeProcessInfo::resetBlock(Block && block_, size_t partition_index_) void ProbeProcessInfo::prepareForHashProbe( const Names & key_names, + const std::vector & is_null_eq, const String & filter_column, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, @@ -61,11 +62,14 @@ void ProbeProcessInfo::prepareForHashProbe( /// Note: this variable can't be removed because it will take smart pointers' lifecycle to the end of this function. hash_join_data->key_columns = extractAndMaterializeKeyColumns(block, hash_join_data->materialized_columns, key_names); - /// Keys with NULL value in any column won't join to anything. - extractNestedColumnsAndNullMap(hash_join_data->key_columns, null_map_holder, null_map); - /// reuse null_map to record the filtered rows, the rows contains NULL or does not - /// match the join filter won't join to anything - recordFilteredRows(block, filter_column, null_map_holder, null_map); + /// Build a unified row filter map: ordinary '=' key NULLs and side-condition failures skip probing, + /// while NullEQ key NULLs remain probeable. + extractJoinKeyColumnsAndFilterNullMap( + hash_join_data->key_columns, + is_null_eq, + row_filter_map_holder, + row_filter_map); + recordFilteredRows(block, filter_column, row_filter_map_holder, row_filter_map); size_t existing_columns = block.columns(); /** If you use FULL or RIGHT JOIN, then the columns from the "left" table must be materialized. @@ -121,7 +125,7 @@ void ProbeProcessInfo::prepareForCrossProbe( cross_join_data->cross_probe_mode = cross_probe_mode_; cross_join_data->right_block_size = right_block_size_; - recordFilteredRows(block, filter_column, null_map_holder, null_map); + recordFilteredRows(block, filter_column, row_filter_map_holder, row_filter_map); if (kind == ASTTableJoin::Kind::Cross_Anti && strictness == ASTTableJoin::Strictness::All) /// `CrossJoinAdder` will skip the matched rows directly, so filter is not needed filter = std::make_unique(block.rows()); @@ -157,8 +161,8 @@ void ProbeProcessInfo::prepareForCrossProbe( } } } - if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && null_map != nullptr) - cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*null_map); + if (cross_join_data->cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK && row_filter_map != nullptr) + cross_join_data->row_num_filtered_by_left_condition = countBytesInFilter(*row_filter_map); prepare_for_probe_done = true; } @@ -179,9 +183,14 @@ void ProbeProcessInfo::prepareForNullAware(const Names & key_names, const String null_aware_join_data->all_key_null_map_holder, null_aware_join_data->all_key_null_map); - extractNestedColumnsAndNullMap(null_aware_join_data->key_columns, null_map_holder, null_map); + extractNestedColumnsAndNullMap( + null_aware_join_data->key_columns, + null_aware_join_data->key_null_map_holder, + null_aware_join_data->key_null_map); - recordFilteredRows(block, filter_column, null_aware_join_data->filter_map_holder, null_aware_join_data->filter_map); + // Reuse the generic probe-side row filter map, but for null-aware join it only records + // rows filtered out by side conditions. Key-null rows are tracked separately in key_null_map. + recordFilteredRows(block, filter_column, row_filter_map_holder, row_filter_map); prepare_for_probe_done = true; } diff --git a/dbms/src/Interpreters/ProbeProcessInfo.h b/dbms/src/Interpreters/ProbeProcessInfo.h index a45d9a4b3aa..bd7841b64db 100644 --- a/dbms/src/Interpreters/ProbeProcessInfo.h +++ b/dbms/src/Interpreters/ProbeProcessInfo.h @@ -76,16 +76,18 @@ struct NullAwareJoinProbeProcessData { Columns materialized_columns; ColumnRawPtrs key_columns; - ColumnPtr filter_map_holder = nullptr; - ConstNullMapPtr filter_map = nullptr; + /// Rows where any null-aware join key is NULL. + ColumnPtr key_null_map_holder = nullptr; + ConstNullMapPtr key_null_map = nullptr; + /// Rows where all null-aware join keys are NULL. ColumnPtr all_key_null_map_holder = nullptr; ConstNullMapPtr all_key_null_map = nullptr; void reset() { key_columns.clear(); materialized_columns.clear(); - filter_map_holder = nullptr; - filter_map = nullptr; + key_null_map_holder = nullptr; + key_null_map = nullptr; all_key_null_map_holder = nullptr; all_key_null_map = nullptr; } @@ -104,8 +106,12 @@ struct ProbeProcessInfo /// these should be inited before probe each block bool prepare_for_probe_done = false; - ColumnPtr null_map_holder = nullptr; - ConstNullMapPtr null_map = nullptr; + /// Unified probe-side row filter map. + /// For regular hash/cross join, it contains ordinary '=' key NULLs plus side-condition failures. + /// For null-aware join, it contains side-condition failures only; key-null rows stay in + /// null_aware_join_data->key_null_map. + ColumnPtr row_filter_map_holder = nullptr; + ConstNullMapPtr row_filter_map = nullptr; /// Used with ANY INNER ANTI JOIN std::unique_ptr filter = nullptr; /// Used with ALL ... JOIN @@ -171,6 +177,7 @@ struct ProbeProcessInfo void prepareForHashProbe( const Names & key_names, + const std::vector & is_null_eq, const String & filter_column, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, diff --git a/dbms/src/Interpreters/tests/gtest_join_null_eq.cpp b/dbms/src/Interpreters/tests/gtest_join_null_eq.cpp new file mode 100644 index 00000000000..f3b2bc1b8ce --- /dev/null +++ b/dbms/src/Interpreters/tests/gtest_join_null_eq.cpp @@ -0,0 +1,1245 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB::tests +{ +namespace +{ +constexpr auto * test_key_name = "k"; +constexpr auto * probe_eq_key_name = "eq_k"; +constexpr auto * probe_null_eq_key_name = "null_eq_k"; +constexpr auto * probe_filter_name = "filter"; +constexpr auto * outer_probe_key_name = "probe_k"; +constexpr auto * outer_probe_value_name = "probe_v"; +constexpr auto * outer_probe_filter_name = "probe_filter"; +constexpr auto * outer_build_key_name = "build_k"; +constexpr auto * outer_build_value_name = "build_v"; +constexpr auto * outer_build_filter_name = "build_filter"; +constexpr auto * mixed_probe_key1_name = "probe_k1"; +constexpr auto * mixed_probe_key2_name = "probe_k2"; +constexpr auto * mixed_probe_value_name = "probe_multi_v"; +constexpr auto * mixed_build_key1_name = "build_k1"; +constexpr auto * mixed_build_key2_name = "build_k2"; +constexpr auto * mixed_build_value_name = "build_multi_v"; +constexpr auto * full_other_cond_name = "full_other_cond"; +constexpr auto * full_flag_helper_name = "__full_flag_helper"; + +Block makeOuterProbeSampleBlock(const DataTypePtr & key_type, bool include_filter); +Block makeOuterBuildSampleBlock(const DataTypePtr & key_type, bool include_filter); +void prepareAndFinalizeMixedJoin(const JoinPtr & join, const DataTypePtr & key_type); + +void ensureFunctionsRegistered() +{ + static std::once_flag once; + std::call_once(once, [] { + try + { + registerFunctions(); + } + catch (DB::Exception &) + { + // Another test suite may have already registered the functions. + } + }); +} + +Block makeSampleBlock(const DataTypePtr & key_type) +{ + return Block{{key_type->createColumn(), key_type, test_key_name}}; +} + +JoinPtr makeTestJoin(const DataTypePtr & key_type, const std::vector & is_null_eq) +{ + SpillConfig build_spill_config("/tmp", "join_null_eq_build", 0, 0, 0, nullptr); + SpillConfig probe_spill_config("/tmp", "join_null_eq_probe", 0, 0, 0, nullptr); + return std::make_shared( + Names{test_key_name}, + Names{test_key_name}, + is_null_eq, + ASTTableJoin::Kind::Inner, + "join_null_eq_test", + 0, + 0, + build_spill_config, + probe_spill_config, + RestoreConfig{1, 0, 0}, + NamesAndTypes{{test_key_name, key_type}}, + RegisterOperatorSpillContext{}, + nullptr, + TiDB::TiDBCollators{}, + JoinNonEqualConditions{}, + 1024, + 0, + "", + "", + 0, + true); +} + +JoinNonEqualConditions makeFullJoinOtherCondition() +{ + ensureFunctionsRegistered(); + + auto nullable_int_type = makeNullable(std::make_shared()); + auto actions = std::make_shared(NamesAndTypes{ + {outer_probe_key_name, nullable_int_type}, + {outer_probe_value_name, nullable_int_type}, + {outer_build_key_name, nullable_int_type}, + {outer_build_value_name, nullable_int_type}, + }); + auto equals_builder = FunctionFactory::instance().get("equals", *TiFlashTestEnv::getContext()); + actions->add(ExpressionAction::applyFunction( + equals_builder, + {outer_probe_value_name, outer_build_value_name}, + full_other_cond_name)); + + JoinNonEqualConditions conditions; + conditions.other_cond_name = full_other_cond_name; + conditions.other_cond_expr = actions; + return conditions; +} + +JoinNonEqualConditions makeOuterJoinSideConditions( + const String & left_filter_column = "", + const String & right_filter_column = "") +{ + JoinNonEqualConditions conditions; + conditions.left_filter_column = left_filter_column; + conditions.right_filter_column = right_filter_column; + return conditions; +} + +JoinPtr makeOuterJoinTestJoin( + ASTTableJoin::Kind kind, + const DataTypePtr & key_type, + const JoinNonEqualConditions & non_equal_conditions = JoinNonEqualConditions{}, + const String & flag_helper_name = "") +{ + auto nullable_value_type = makeNullable(std::make_shared()); + SpillConfig build_spill_config("/tmp", "join_null_eq_build", 0, 0, 0, nullptr); + SpillConfig probe_spill_config("/tmp", "join_null_eq_probe", 0, 0, 0, nullptr); + return std::make_shared( + Names{outer_probe_key_name}, + Names{outer_build_key_name}, + std::vector{1}, + kind, + "join_null_eq_outer_test", + 0, + 0, + build_spill_config, + probe_spill_config, + RestoreConfig{1, 0, 0}, + NamesAndTypes{ + {outer_probe_key_name, key_type}, + {outer_probe_value_name, nullable_value_type}, + {outer_build_key_name, key_type}, + {outer_build_value_name, nullable_value_type}, + }, + RegisterOperatorSpillContext{}, + nullptr, + TiDB::TiDBCollators{}, + non_equal_conditions, + 1024, + 0, + "", + flag_helper_name, + 0, + true); +} + +JoinPtr makeOuterJoinTestJoin( + ASTTableJoin::Kind kind, + const JoinNonEqualConditions & non_equal_conditions = JoinNonEqualConditions{}, + const String & flag_helper_name = "") +{ + return makeOuterJoinTestJoin( + kind, + makeNullable(std::make_shared()), + non_equal_conditions, + flag_helper_name); +} + +JoinPtr makeSemiJoinTestJoin(ASTTableJoin::Kind kind) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + SpillConfig build_spill_config("/tmp", "join_null_eq_build", 0, 0, 0, nullptr); + SpillConfig probe_spill_config("/tmp", "join_null_eq_probe", 0, 0, 0, nullptr); + return std::make_shared( + Names{outer_probe_key_name}, + Names{outer_build_key_name}, + std::vector{1}, + kind, + "join_null_eq_semi_test", + 0, + 0, + build_spill_config, + probe_spill_config, + RestoreConfig{1, 0, 0}, + NamesAndTypes{ + {outer_probe_key_name, nullable_int_type}, + {outer_probe_value_name, int_type}, + }, + RegisterOperatorSpillContext{}, + nullptr, + TiDB::TiDBCollators{}, + JoinNonEqualConditions{}, + 1024, + 0, + "", + "", + 0, + true); +} + +JoinPtr makeMixedKeyJoin(const std::vector & is_null_eq, const DataTypePtr & key_type) +{ + auto int_type = std::make_shared(); + SpillConfig build_spill_config("/tmp", "join_null_eq_build", 0, 0, 0, nullptr); + SpillConfig probe_spill_config("/tmp", "join_null_eq_probe", 0, 0, 0, nullptr); + return std::make_shared( + Names{mixed_probe_key1_name, mixed_probe_key2_name}, + Names{mixed_build_key1_name, mixed_build_key2_name}, + is_null_eq, + ASTTableJoin::Kind::Inner, + "join_null_eq_mixed_test", + 0, + 0, + build_spill_config, + probe_spill_config, + RestoreConfig{1, 0, 0}, + NamesAndTypes{ + {mixed_probe_key1_name, key_type}, + {mixed_probe_key2_name, key_type}, + {mixed_probe_value_name, int_type}, + {mixed_build_key1_name, key_type}, + {mixed_build_key2_name, key_type}, + {mixed_build_value_name, int_type}, + }, + RegisterOperatorSpillContext{}, + nullptr, + TiDB::TiDBCollators{}, + JoinNonEqualConditions{}, + 1024, + 0, + "", + "", + 0, + true); +} + +JoinPtr makeMixedKeyJoin(const std::vector & is_null_eq) +{ + return makeMixedKeyJoin(is_null_eq, makeNullable(std::make_shared())); +} + +Block makeOuterProbeSampleBlock(bool include_filter = false) +{ + return makeOuterProbeSampleBlock(makeNullable(std::make_shared()), include_filter); +} + +Block makeOuterProbeSampleBlock(const DataTypePtr & key_type, bool include_filter = false) +{ + auto int_type = std::make_shared(); + auto block = Block{ + {key_type->createColumn(), key_type, outer_probe_key_name}, + {int_type->createColumn(), int_type, outer_probe_value_name}, + }; + if (include_filter) + { + auto uint8_type = std::make_shared(); + block.insert({uint8_type->createColumn(), uint8_type, outer_probe_filter_name}); + } + return block; +} + +Block makeOuterBuildSampleBlock(bool include_filter = false) +{ + return makeOuterBuildSampleBlock(makeNullable(std::make_shared()), include_filter); +} + +Block makeOuterBuildSampleBlock(const DataTypePtr & key_type, bool include_filter = false) +{ + auto int_type = std::make_shared(); + auto block = Block{ + {key_type->createColumn(), key_type, outer_build_key_name}, + {int_type->createColumn(), int_type, outer_build_value_name}, + }; + if (include_filter) + { + auto uint8_type = std::make_shared(); + block.insert({uint8_type->createColumn(), uint8_type, outer_build_filter_name}); + } + return block; +} + +Block makeMixedProbeSampleBlock(const DataTypePtr & key_type) +{ + auto int_type = std::make_shared(); + return Block{ + {key_type->createColumn(), key_type, mixed_probe_key1_name}, + {key_type->createColumn(), key_type, mixed_probe_key2_name}, + {int_type->createColumn(), int_type, mixed_probe_value_name}, + }; +} + +Block makeMixedBuildSampleBlock(const DataTypePtr & key_type) +{ + auto int_type = std::make_shared(); + return Block{ + {key_type->createColumn(), key_type, mixed_build_key1_name}, + {key_type->createColumn(), key_type, mixed_build_key2_name}, + {int_type->createColumn(), int_type, mixed_build_value_name}, + }; +} + +template +ColumnPtr makeNullableNumberColumn(std::initializer_list> values) +{ + auto nested = ColumnType::create(); + auto null_map = ColumnUInt8::create(); + nested->reserve(values.size()); + null_map->reserve(values.size()); + auto & nested_data = nested->getData(); + auto & null_map_data = null_map->getData(); + for (const auto & value : values) + { + if (value.has_value()) + { + nested_data.push_back(*value); + null_map_data.push_back(0); + } + else + { + nested_data.push_back(0); + null_map_data.push_back(1); + } + } + return ColumnNullable::create(std::move(nested), std::move(null_map)); +} + +ColumnPtr makeNullableInt32Column(std::initializer_list> values) +{ + return makeNullableNumberColumn(values); +} + +ColumnPtr makeNullableInt64Column(std::initializer_list> values) +{ + return makeNullableNumberColumn(values); +} + +ColumnPtr makeNullableStringColumn(std::initializer_list> values) +{ + auto nested = ColumnString::create(); + auto null_map = ColumnUInt8::create(); + null_map->reserve(values.size()); + auto & null_map_data = null_map->getData(); + for (const auto & value : values) + { + if (value.has_value()) + { + nested->insertData(value->data(), value->size()); + null_map_data.push_back(0); + } + else + { + nested->insertData("", 0); + null_map_data.push_back(1); + } + } + return ColumnNullable::create(std::move(nested), std::move(null_map)); +} + +ColumnPtr makeNullableFixedStringColumn(size_t string_size, std::initializer_list> values) +{ + auto nested = ColumnFixedString::create(string_size); + auto null_map = ColumnUInt8::create(); + null_map->reserve(values.size()); + auto & null_map_data = null_map->getData(); + for (const auto & value : values) + { + if (value.has_value()) + { + nested->insertData(value->data(), value->size()); + null_map_data.push_back(0); + } + else + { + nested->insertData("", 0); + null_map_data.push_back(1); + } + } + return ColumnNullable::create(std::move(nested), std::move(null_map)); +} + +ColumnPtr makeUInt8Column(std::initializer_list values) +{ + auto column = ColumnUInt8::create(); + column->reserve(values.size()); + auto & data = column->getData(); + for (auto value : values) + data.push_back(value); + return column; +} + +ColumnPtr makeInt32Column(std::initializer_list values) +{ + auto column = ColumnInt32::create(); + column->reserve(values.size()); + auto & data = column->getData(); + for (auto value : values) + data.push_back(value); + return column; +} + +Block readAllBlocks(const BlockInputStreamPtr & stream) +{ + stream->readPrefix(); + Blocks blocks; + while (true) + { + auto block = stream->read(); + if (!block) + break; + blocks.push_back(std::move(block)); + } + stream->readSuffix(); + if (blocks.empty()) + return stream->getHeader().cloneEmpty(); + return vstackBlocks(std::move(blocks)); +} + +std::optional getInt32Value(const Block & block, const String & name, size_t row) +{ + const auto & column = block.getByName(name).column; + if (const auto * nullable_column = checkAndGetColumn(column.get()); nullable_column != nullptr) + { + if (nullable_column->getNullMapData()[row] != 0) + return std::nullopt; + return checkAndGetColumn(nullable_column->getNestedColumnPtr().get())->getData()[row]; + } + return checkAndGetColumn(column.get())->getData()[row]; +} + +void prepareAndFinalizeOuterJoin( + const JoinPtr & join, + bool include_probe_filter = false, + bool include_build_filter = false) +{ + join->initBuild(makeOuterBuildSampleBlock(include_build_filter), 1); + join->initProbe(makeOuterProbeSampleBlock(include_probe_filter), 1); + join->finalize(Names{ + outer_probe_key_name, + outer_probe_value_name, + outer_build_key_name, + outer_build_value_name, + }); +} + +void prepareAndFinalizeOuterJoin( + const JoinPtr & join, + const DataTypePtr & key_type, + bool include_probe_filter = false, + bool include_build_filter = false) +{ + join->initBuild(makeOuterBuildSampleBlock(key_type, include_build_filter), 1); + join->initProbe(makeOuterProbeSampleBlock(key_type, include_probe_filter), 1); + join->finalize(Names{ + outer_probe_key_name, + outer_probe_value_name, + outer_build_key_name, + outer_build_value_name, + }); +} + +void prepareAndFinalizeSemiJoin(const JoinPtr & join) +{ + join->initBuild(makeOuterBuildSampleBlock(), 1); + join->initProbe(makeOuterProbeSampleBlock(), 1); + join->finalize(Names{ + outer_probe_key_name, + outer_probe_value_name, + }); +} + +void prepareAndFinalizeMixedJoin(const JoinPtr & join) +{ + prepareAndFinalizeMixedJoin(join, makeNullable(std::make_shared())); +} + +void prepareAndFinalizeMixedJoin(const JoinPtr & join, const DataTypePtr & key_type) +{ + join->initBuild(makeMixedBuildSampleBlock(key_type), 1); + join->initProbe(makeMixedProbeSampleBlock(key_type), 1); + join->finalize(Names{ + mixed_probe_key1_name, + mixed_probe_key2_name, + mixed_probe_value_name, + mixed_build_key1_name, + mixed_build_key2_name, + mixed_build_value_name, + }); +} +} // namespace + +TEST(JoinNullEqTest, NullableNullEqKeyUsesNullablePackedJoinMapMethod) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto join = makeTestJoin(nullable_int_type, {1}); + join->initBuild(makeSampleBlock(nullable_int_type), 1); + + ASSERT_EQ(join->getJoinMapMethod(), JoinMapMethod::nullable_keys128); +} + +TEST(JoinNullEqTest, NullableMixedNullEqKeysCanUseNullableKeys256JoinMapMethod) +{ + auto nullable_int64_type = makeNullable(std::make_shared()); + auto join = makeMixedKeyJoin({1, 1}, nullable_int64_type); + join->initBuild(makeMixedBuildSampleBlock(nullable_int64_type), 1); + + ASSERT_EQ(join->getJoinMapMethod(), JoinMapMethod::nullable_keys256); +} + +TEST(JoinNullEqTest, NullableMixedNullEqKeys256JoinProducesJoinedRow) +{ + auto nullable_int64_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeMixedKeyJoin({1, 1}, nullable_int64_type); + prepareAndFinalizeMixedJoin(join, nullable_int64_type); + + ASSERT_EQ(join->getJoinMapMethod(), JoinMapMethod::nullable_keys256); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt64Column({std::nullopt}), nullable_int64_type, mixed_build_key1_name}, + {makeNullableInt64Column({11}), nullable_int64_type, mixed_build_key2_name}, + {makeInt32Column({100}), int_type, mixed_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt64Column({std::nullopt}), nullable_int64_type, mixed_probe_key1_name}, + {makeNullableInt64Column({11}), nullable_int64_type, mixed_probe_key2_name}, + {makeInt32Column({10}), int_type, mixed_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, NullableStringNullEqFallsBackToSerializedJoinMapMethod) +{ + auto nullable_string_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Inner, nullable_string_type); + prepareAndFinalizeOuterJoin(join, nullable_string_type); + + ASSERT_EQ(join->getJoinMapMethod(), JoinMapMethod::serialized); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableStringColumn({std::nullopt, "alpha"}), nullable_string_type, outer_build_key_name}, + {makeInt32Column({100, 200}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableStringColumn({std::nullopt, "alpha", "beta"}), nullable_string_type, outer_probe_key_name}, + {makeInt32Column({10, 20, 30}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 2); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 1), 20); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 1), 200); +} + +TEST(JoinNullEqTest, OversizedNullableFixedKeysFallBackToSerializedJoinMapMethod) +{ + constexpr size_t fixed_string_size = 16; + auto nullable_fixed_string_type = makeNullable(std::make_shared(fixed_string_size)); + auto int_type = std::make_shared(); + auto join = makeMixedKeyJoin({1, 1}, nullable_fixed_string_type); + prepareAndFinalizeMixedJoin(join, nullable_fixed_string_type); + + ASSERT_EQ(join->getJoinMapMethod(), JoinMapMethod::serialized); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableFixedStringColumn(fixed_string_size, {std::nullopt}), + nullable_fixed_string_type, + mixed_build_key1_name}, + {makeNullableFixedStringColumn(fixed_string_size, {"abcdefghijklmnop"}), + nullable_fixed_string_type, + mixed_build_key2_name}, + {makeInt32Column({100}), int_type, mixed_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableFixedStringColumn(fixed_string_size, {std::nullopt, std::nullopt}), + nullable_fixed_string_type, + mixed_probe_key1_name}, + {makeNullableFixedStringColumn(fixed_string_size, {"abcdefghijklmnop", "qrstuvwxyzabcdef"}), + nullable_fixed_string_type, + mixed_probe_key2_name}, + {makeInt32Column({10, 20}), int_type, mixed_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, DefaultMethodSelectionRemainsForOtherCases) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto nullable_non_null_eq_join = makeTestJoin(nullable_int_type, {0}); + nullable_non_null_eq_join->initBuild(makeSampleBlock(nullable_int_type), 1); + ASSERT_EQ(nullable_non_null_eq_join->getJoinMapMethod(), JoinMapMethod::key32); + + auto int_type = std::make_shared(); + auto non_nullable_null_eq_join = makeTestJoin(int_type, {1}); + non_nullable_null_eq_join->initBuild(makeSampleBlock(int_type), 1); + ASSERT_EQ(non_nullable_null_eq_join->getJoinMapMethod(), JoinMapMethod::key32); +} + +TEST(JoinNullEqTest, NullableNullEqBuildRowsAreInsertedIntoHashMap) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto join = makeTestJoin(nullable_int_type, {1}); + join->initBuild(makeSampleBlock(nullable_int_type), 1); + + Block build_block{ + {makeNullableInt32Column({std::nullopt, 7}), nullable_int_type, test_key_name}, + }; + join->insertFromBlock(build_block, 0); + + ASSERT_EQ(join->getTotalRowCount(), 2); +} + +TEST(JoinNullEqTest, ProbeRowFilterSkipsOnlyNonNullEqNullKeys) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto uint8_type = std::make_shared(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({1, std::nullopt, 2}), nullable_int_type, probe_eq_key_name}, + {makeNullableInt32Column({std::nullopt, 10, 20}), nullable_int_type, probe_null_eq_key_name}, + {makeUInt8Column({1, 1, 0}), uint8_type, probe_filter_name}, + }); + + probe_process_info.prepareForHashProbe( + Names{probe_eq_key_name, probe_null_eq_key_name}, + std::vector{0, 1}, + probe_filter_name, + ASTTableJoin::Kind::Inner, + ASTTableJoin::Strictness::All, + false, + TiDB::TiDBCollators{}, + 0); + + ASSERT_NE(probe_process_info.row_filter_map, nullptr); + ASSERT_FALSE(probe_process_info.hash_join_data->key_columns[0]->isColumnNullable()); + ASSERT_TRUE(probe_process_info.hash_join_data->key_columns[1]->isColumnNullable()); + EXPECT_EQ((*probe_process_info.row_filter_map)[0], 0); + EXPECT_EQ((*probe_process_info.row_filter_map)[1], 1); + EXPECT_EQ((*probe_process_info.row_filter_map)[2], 1); +} + +TEST(JoinNullEqTest, InnerJoinNullEqNullMatchProducesJoinedRow) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Inner); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, InnerJoinNullEqNullProbeRowDoesNotMatchNonNullBuildRow) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Inner); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({7}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + EXPECT_EQ(probe_result.rows(), 0); +} + +TEST(JoinNullEqTest, LeftOuterNullEqNullMatchDoesNotBecomeUnmatched) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::LeftOuter); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, LeftOuterNullEqNullProbeRowStaysUnmatchedAgainstNonNullBuildRow) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::LeftOuter); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({7}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), std::nullopt); +} + +TEST(JoinNullEqTest, LeftOuterNullEqLeftConditionOnlyFiltersConditionFailure) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto uint8_type = std::make_shared(); + auto join = makeOuterJoinTestJoin( + ASTTableJoin::Kind::LeftOuter, + makeOuterJoinSideConditions(outer_probe_filter_name, "")); + prepareAndFinalizeOuterJoin(join, true, false); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10, 20}), int_type, outer_probe_value_name}, + {makeUInt8Column({1, 0}), uint8_type, outer_probe_filter_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 2); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 1), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 1), 20); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 1), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 1), std::nullopt); +} + +TEST(JoinNullEqTest, SemiJoinNullEqKeepsOnlyProbeRowsThatMatch) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeSemiJoinTestJoin(ASTTableJoin::Kind::Semi); + prepareAndFinalizeSemiJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt, 7}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100, 200}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt, 7, 8}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10, 20, 30}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 2); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 1), 7); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 1), 20); +} + +TEST(JoinNullEqTest, AntiJoinNullEqKeepsOnlyProbeRowsThatDoNotMatch) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeSemiJoinTestJoin(ASTTableJoin::Kind::Anti); + prepareAndFinalizeSemiJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt, 7}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100, 200}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt, 7, 8}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10, 20, 30}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), 8); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 30); +} + +TEST(JoinNullEqTest, MixedJoinKeysNullEqThenEqOnlyFirstKeyIsNullSafe) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeMixedKeyJoin({1, 0}); + prepareAndFinalizeMixedJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, mixed_build_key1_name}, + {makeNullableInt32Column({1, std::nullopt}), nullable_int_type, mixed_build_key2_name}, + {makeInt32Column({100, 200}), int_type, mixed_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, mixed_probe_key1_name}, + {makeNullableInt32Column({1, std::nullopt}), nullable_int_type, mixed_probe_key2_name}, + {makeInt32Column({10, 20}), int_type, mixed_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key1_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key2_name, 0), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key1_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key2_name, 0), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, MixedJoinKeysAllNullEqAllowAllNullPairToMatch) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeMixedKeyJoin({1, 1}); + prepareAndFinalizeMixedJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, mixed_build_key1_name}, + {makeNullableInt32Column({std::nullopt}), nullable_int_type, mixed_build_key2_name}, + {makeInt32Column({100}), int_type, mixed_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, mixed_probe_key1_name}, + {makeNullableInt32Column({std::nullopt}), nullable_int_type, mixed_probe_key2_name}, + {makeInt32Column({10}), int_type, mixed_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key1_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key2_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key1_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key2_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, MixedJoinKeysEqThenNullEqOnlySecondKeyIsNullSafe) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeMixedKeyJoin({0, 1}); + prepareAndFinalizeMixedJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({7, std::nullopt}), nullable_int_type, mixed_build_key1_name}, + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, mixed_build_key2_name}, + {makeInt32Column({100, 200}), int_type, mixed_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({7, std::nullopt}), nullable_int_type, mixed_probe_key1_name}, + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, mixed_probe_key2_name}, + {makeInt32Column({10, 20}), int_type, mixed_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key1_name, 0), 7); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_key2_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key1_name, 0), 7); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_key2_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, mixed_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, RightOuterNullEqNullMatchDoesNotLeakToScanAfterProbe) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::RightOuter); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + EXPECT_EQ(scan_result.rows(), 0); +} + +TEST(JoinNullEqTest, RightOuterNullEqUnmatchedNullBuildRowStillScansFromHashMap) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::RightOuter); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + ASSERT_EQ(scan_result.rows(), 1); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_value_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, RightOuterNullEqRightConditionFilteredBuildRowStillScansAfterProbe) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto uint8_type = std::make_shared(); + auto join = makeOuterJoinTestJoin( + ASTTableJoin::Kind::RightOuter, + makeOuterJoinSideConditions("", outer_build_filter_name)); + prepareAndFinalizeOuterJoin(join, false, true); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt, std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100, 200}), int_type, outer_build_value_name}, + {makeUInt8Column({1, 0}), uint8_type, outer_build_filter_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + ASSERT_EQ(scan_result.rows(), 1); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_value_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_value_name, 0), 200); +} + +TEST(JoinNullEqTest, FullJoinNullEqNullMatchDoesNotSplitIntoTwoUnmatchedRows) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Full); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + EXPECT_EQ(scan_result.rows(), 0); +} + +TEST(JoinNullEqTest, FullJoinNullEqMatchWithOtherConditionFalseKeepsBuildRowForScanAfterProbe) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Full, makeFullJoinOtherCondition(), full_flag_helper_name); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({10}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 10); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), std::nullopt); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + ASSERT_EQ(scan_result.rows(), 1); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_probe_value_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(scan_result, outer_build_value_name, 0), 100); +} + +TEST(JoinNullEqTest, FullJoinNullEqMatchWithOtherConditionTrueConsumesBuildRow) +{ + auto nullable_int_type = makeNullable(std::make_shared()); + auto int_type = std::make_shared(); + auto join = makeOuterJoinTestJoin(ASTTableJoin::Kind::Full, makeFullJoinOtherCondition(), full_flag_helper_name); + prepareAndFinalizeOuterJoin(join); + + join->setInitActiveBuildThreads(); + join->insertFromBlock( + Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_build_key_name}, + {makeInt32Column({100}), int_type, outer_build_value_name}, + }, + 0); + ASSERT_TRUE(join->finishOneBuild(0)); + join->finalizeBuild(); + + ProbeProcessInfo probe_process_info(1024, 0); + probe_process_info.resetBlock(Block{ + {makeNullableInt32Column({std::nullopt}), nullable_int_type, outer_probe_key_name}, + {makeInt32Column({100}), int_type, outer_probe_value_name}, + }); + Block probe_result = join->joinBlock(probe_process_info); + + ASSERT_EQ(probe_result.rows(), 1); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_probe_value_name, 0), 100); + EXPECT_EQ(getInt32Value(probe_result, outer_build_key_name, 0), std::nullopt); + EXPECT_EQ(getInt32Value(probe_result, outer_build_value_name, 0), 100); + + ASSERT_TRUE(join->finishOneProbe(0)); + join->finalizeProbe(); + + Block scan_result = readAllBlocks(join->createScanHashMapAfterProbeStream(makeOuterProbeSampleBlock(), 0, 1, 1024)); + EXPECT_EQ(scan_result.rows(), 0); +} + +} // namespace DB::tests diff --git a/dbms/src/TestUtils/ColumnsToTiPBExpr.h b/dbms/src/TestUtils/ColumnsToTiPBExpr.h index ecc5ddb01e1..409a1f62f37 100644 --- a/dbms/src/TestUtils/ColumnsToTiPBExpr.h +++ b/dbms/src/TestUtils/ColumnsToTiPBExpr.h @@ -20,6 +20,7 @@ #include #include #include +#include namespace DB { diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index 744fe54c509..89f39928754 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -332,7 +332,8 @@ DAGRequestBuilder & DAGRequestBuilder::join( MockAstVec other_eq_conds_from_in, uint64_t fine_grained_shuffle_stream_count, bool is_null_aware_semi_join, - int64_t inner_index) + int64_t inner_index, + std::vector is_null_eq) { assert(root); assert(right.root); @@ -342,6 +343,7 @@ DAGRequestBuilder & DAGRequestBuilder::join( right.root, tp, join_col_exprs, + is_null_eq, left_conds, right_conds, other_conds, diff --git a/dbms/src/TestUtils/mockExecutor.h b/dbms/src/TestUtils/mockExecutor.h index abd2f630923..ba89e821218 100644 --- a/dbms/src/TestUtils/mockExecutor.h +++ b/dbms/src/TestUtils/mockExecutor.h @@ -146,7 +146,8 @@ class DAGRequestBuilder MockAstVec other_eq_conds_from_in, uint64_t fine_grained_shuffle_stream_count = 0, bool is_null_aware_semi_join = false, - int64_t inner_index = 1); + int64_t inner_index = 1, + std::vector is_null_eq = {}); DAGRequestBuilder & join( const DAGRequestBuilder & right, tipb::JoinType tp, diff --git a/dbms/src/TestUtils/tests/gtest_mock_executors.cpp b/dbms/src/TestUtils/tests/gtest_mock_executors.cpp index ff29392180f..2a41fb615fa 100644 --- a/dbms/src/TestUtils/tests/gtest_mock_executors.cpp +++ b/dbms/src/TestUtils/tests/gtest_mock_executors.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -42,6 +43,18 @@ class MockDAGRequestTest : public DB::tests::ExecutorTest } }; +namespace +{ +void collectColumnRefs(const tipb::Expr & expr, std::vector & column_refs) +{ + if (expr.tp() == tipb::ExprType::ColumnRef) + column_refs.push_back(&expr); + + for (const auto & child : expr.children()) + collectColumnRefs(child, column_refs); +} +} // namespace + TEST_F(MockDAGRequestTest, MockTable) try { @@ -265,6 +278,100 @@ try } CATCH +TEST_F(MockDAGRequestTest, JoinNullEqSerialization) +try +{ + auto request = context.scan("test_db", "l_table") + .join( + context.scan("test_db", "r_table"), + tipb::JoinType::TypeInnerJoin, + {col("join_c")}, + {}, + {}, + {}, + {}, + 0, + false, + 1, + {1}) + .build(context); + + ASSERT_EQ(request->root_executor().tp(), tipb::ExecType::TypeJoin); + const auto & join = request->root_executor().join(); + ASSERT_EQ(join.is_null_eq_size(), 1); + ASSERT_TRUE(join.is_null_eq(0)); +} +CATCH + +TEST_F(MockDAGRequestTest, FullOuterJoinSchemaIsNullable) +try +{ + context.addMockTable( + {"full_outer_test", "l"}, + {{"a", TiDB::TP::TypeLong, false}, {"b", TiDB::TP::TypeLong, false}}); + context.addMockTable( + {"full_outer_test", "r"}, + {{"a", TiDB::TP::TypeLong, false}, {"c", TiDB::TP::TypeLong, false}}); + + auto request = context.scan("full_outer_test", "l") + .join( + context.scan("full_outer_test", "r"), + tipb::JoinType::TypeFullOuterJoin, + {col("a")}, + {}, + {}, + {gt(col("b"), col("c"))}, + {}) + .build(context); + + ASSERT_EQ(request->root_executor().tp(), tipb::ExecType::TypeJoin); + const auto & join = request->root_executor().join(); + ASSERT_EQ(join.other_conditions_size(), 1); + + std::vector column_refs; + collectColumnRefs(join.other_conditions(0), column_refs); + ASSERT_EQ(column_refs.size(), 2); + for (const auto * column_ref : column_refs) + ASSERT_EQ(column_ref->field_type().flag() & TiDB::ColumnFlagNotNull, 0); + + auto output_field_types = collectOutputFieldTypes(*request); + ASSERT_EQ(output_field_types.size(), 4); + for (const auto & field_type : output_field_types) + ASSERT_EQ(field_type.flag() & TiDB::ColumnFlagNotNull, 0); +} +CATCH + +TEST_F(MockDAGRequestTest, SemiJoinColumnPruneKeepsJoinOutputSchema) +try +{ + const std::vector> test_cases = { + {tipb::JoinType::TypeSemiJoin, 3}, + {tipb::JoinType::TypeAntiSemiJoin, 3}, + {tipb::JoinType::TypeLeftOuterSemiJoin, 4}, + {tipb::JoinType::TypeAntiLeftOuterSemiJoin, 4}, + }; + + for (const auto & [join_type, expected_size] : test_cases) + { + auto request = context.scan("test_db", "l_table") + .join(context.scan("test_db", "r_table"), join_type, {col("join_c")}) + .build(context); + + auto output_field_types = collectOutputFieldTypes(*request); + ASSERT_EQ(output_field_types.size(), expected_size) << fmt::underlying(join_type); + ASSERT_EQ(output_field_types[0].tp(), TiDB::TypeLong); + ASSERT_EQ(output_field_types[1].tp(), TiDB::TypeString); + ASSERT_EQ(output_field_types[2].tp(), TiDB::TypeString); + + if (expected_size == 4) + { + ASSERT_EQ(output_field_types[3].tp(), TiDB::TypeTiny); + ASSERT_EQ(output_field_types[3].flag() & TiDB::ColumnFlagNotNull, 0); + } + } +} +CATCH + TEST_F(MockDAGRequestTest, ExchangeSender) try { diff --git a/dbms/src/TiDB/Decode/JsonBinary.cpp b/dbms/src/TiDB/Decode/JsonBinary.cpp index 355aa487b66..4a51dd008bc 100644 --- a/dbms/src/TiDB/Decode/JsonBinary.cpp +++ b/dbms/src/TiDB/Decode/JsonBinary.cpp @@ -1098,6 +1098,76 @@ void JsonBinary::buildBinaryJsonArrayInBuffer( buildBinaryJsonElementsInBuffer(json_binary_vec, write_buffer); } +void JsonBinary::buildBinaryJsonObjectInBuffer( + const std::vector & keys, + const std::vector & values, + JsonBinaryWriteBuffer & write_buffer) +{ + RUNTIME_CHECK(keys.size() == values.size()); + + write_buffer.write(TYPE_CODE_OBJECT); + + UInt32 buffer_start_pos = write_buffer.offset(); + + UInt32 element_count = keys.size(); + encodeNumeric(write_buffer, element_count); + + auto total_size_pos = write_buffer.offset(); + write_buffer.advance(4); + + UInt32 data_offset_start = HEADER_SIZE + element_count * (KEY_ENTRY_SIZE + VALUE_ENTRY_SIZE); + UInt32 data_offset = data_offset_start; + for (const auto & key : keys) + { + encodeNumeric(write_buffer, data_offset); + if (unlikely(key.size > std::numeric_limits::max())) + throw Exception("TiDB/TiFlash does not yet support JSON objects with the key length >= 65536"); + UInt16 key_len = key.size; + encodeNumeric(write_buffer, key_len); + data_offset += key.size; + } + + UInt32 value_entry_start_pos = write_buffer.offset(); + + write_buffer.setOffset(buffer_start_pos + data_offset_start); + for (const auto & key : keys) + write_buffer.write(key.data, key.size); + + write_buffer.setOffset(value_entry_start_pos); + UInt64 max_child_depth = 0; + for (const auto & value : values) + { + write_buffer.write(value.type); + if (value.type == TYPE_CODE_LITERAL) + { + write_buffer.write(value.data.data[0]); + write_buffer.write(0); + write_buffer.write(0); + write_buffer.write(0); + } + else + { + encodeNumeric(write_buffer, data_offset); + auto tmp_entry_pos = write_buffer.offset(); + + write_buffer.setOffset(buffer_start_pos + data_offset); + write_buffer.write(value.data.data, value.data.size); + data_offset = write_buffer.offset() - buffer_start_pos; + + write_buffer.setOffset(tmp_entry_pos); + } + max_child_depth = std::max(max_child_depth, value.getDepth()); + } + + UInt64 depth = max_child_depth + 1; + JsonBinary::assertJsonDepth(depth); + + UInt32 total_size = data_offset; + write_buffer.setOffset(total_size_pos); + encodeNumeric(write_buffer, total_size); + write_buffer.setOffset(buffer_start_pos + data_offset); +} + void JsonBinary::buildKeyArrayInBuffer(const std::vector & keys, JsonBinaryWriteBuffer & write_buffer) { write_buffer.write(TYPE_CODE_ARRAY); diff --git a/dbms/src/TiDB/Decode/JsonBinary.h b/dbms/src/TiDB/Decode/JsonBinary.h index c65a68723ec..d13ca157db6 100644 --- a/dbms/src/TiDB/Decode/JsonBinary.h +++ b/dbms/src/TiDB/Decode/JsonBinary.h @@ -173,6 +173,10 @@ class JsonBinary static void buildBinaryJsonArrayInBuffer( const std::vector & json_binary_vec, JsonBinaryWriteBuffer & write_buffer); + static void buildBinaryJsonObjectInBuffer( + const std::vector & keys, + const std::vector & values, + JsonBinaryWriteBuffer & write_buffer); static void buildKeyArrayInBuffer(const std::vector & keys, JsonBinaryWriteBuffer & write_buffer); static void appendNumber(JsonBinaryWriteBuffer & write_buffer, bool value); @@ -327,4 +331,4 @@ void JsonBinary::unquoteJsonStringInBuffer(const StringRef & ref, WriteBuffer & } } } -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/docs/note/fullouter_join.md b/docs/note/fullouter_join.md new file mode 100644 index 00000000000..44956fa625f --- /dev/null +++ b/docs/note/fullouter_join.md @@ -0,0 +1,272 @@ +# TiFlash 支持 FULL OUTER JOIN 改动点梳理(本轮仅等值 Join) + +## 背景 + +TiFlash 内核(继承自 ClickHouse 的 Join 架构)本身有 `ASTTableJoin::Kind::Full` 的基础能力,但在 TiDB 长期不下推 full outer join 的情况下,这条路径没有被持续覆盖,尤其是后续补上的 `other condition` 逻辑基本只覆盖了 left/right outer。 + +现在 TiDB 侧准备支持 full outer join,下推到 TiFlash 后,需要把协议映射、输出 schema、执行语义和测试覆盖一起补齐。 + +本文档本轮 scope 明确限定为:`FULL OUTER JOIN` 且 `left_join_keys/right_join_keys` 非空(即有等值 join key 的 hash join 路径)。 + +## 现状结论 + +### 已有能力(可复用) + +1. SQL/AST 层已识别 Full。 +- `dbms/src/Parsers/ParserTablesInSelectQuery.cpp:117` +- `dbms/src/Parsers/ASTTablesInSelectQuery.cpp:165` + +2. Hash Join 框架里,`Full` 已被视为需要“probe 后扫描 build 侧未匹配行”的 join。 +- `dbms/src/Interpreters/JoinUtils.h:26` +- `dbms/src/Interpreters/JoinUtils.h:84` + +3. 不带 `other condition` 的 full 基础路径基本存在。 +- probe 侧列 nullable 处理:`dbms/src/Interpreters/ProbeProcessInfo.cpp:71` +- build 侧样本列 nullable 处理:`dbms/src/Interpreters/Join.cpp:346` +- build block nullable 处理:`dbms/src/Interpreters/Join.cpp:696` +- full 在 probe 时走 `MapsAllFull`:`dbms/src/Interpreters/JoinPartition.cpp:2124` + +### 主要缺口(必须补) + +1. TiDB DAG 协议和 TiFlash JoinType 映射还没有 full。 +- `contrib/tipb/proto/executor.proto:184` +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:68` + +2. Full 的输出/中间 schema nullable 规则没有补齐(只处理 left/right outer)。 +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:298` +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:333` +- `dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp:175` + +3. `full + other condition` correctness 当前不成立。 +- `handleOtherConditions` 没有 full 分支,可能直接抛逻辑错误:`dbms/src/Interpreters/Join.cpp:1032` +- full 当前走 `MapsAllFull`,key 命中时会提前 `setUsed()`,other condition 过滤后无法恢复“未匹配右行”状态:`dbms/src/Interpreters/JoinPartition.cpp:1601` + +4. left/right condition 校验目前不允许 full。 +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h:97` + +5. 无等值键(join keys 为空)的 full 当前不在本轮范围内(本轮先不打通 cartesian full)。 + +## 具体改动点 + +## 1. 协议与 JoinType 映射 + +1. 在 tipb 加 `TypeFullOuterJoin`(按 TiDB 最终协议值同步)。 +- 位置:`contrib/tipb/proto/executor.proto` + +2. `JoinInterpreterHelper::getJoinKindAndBuildSideIndex` 增加 full 映射。 +- equal join(有 key)至少支持: + - `{TypeFullOuterJoin, 0} -> {ASTTableJoin::Kind::Full, 0}` + - `{TypeFullOuterJoin, 1} -> {ASTTableJoin::Kind::Full, 1}` +- 约定(与 TiDB 测试一致):`FULL OUTER JOIN` 的 build side 由 `inner_idx` 直接指定(即 `build_side_index == inner_idx`)。 +- 说明:full 场景不需要像 left/right outer 那样改 `join kind`;但执行层仍会按 `build_side_index` 做 probe/build 角色接线,以满足 TiFlash 内部 right-build 约定。 + +3. 字符串化/日志分支补 full,避免默认分支报错。 +- `dbms/src/Flash/Coprocessor/DAGUtils.cpp:837` + +## 2. Full 的 nullable/schema 规则 + +1. Join 输出 schema:full 需要左右两侧都 nullable。 +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:333` + +2. other condition 编译输入 schema:full 需要左右两侧(以及 probe prepare 新增列)按 full 语义 nullable。 +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:298` +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp:313` + +3. `collectOutputFieldTypes` 对外字段类型同样要把左右都改为 nullable。 +- `dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp:175` +- `dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp:213` + +## 3. left/right condition 在 full 下的校验 + +`JoinNonEqualConditions::validate` 当前把 left condition 只限定在 left outer,把 right condition 只限定在 right outer;full 下这两类都可能出现,应允许。 + +- 位置:`dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h:95` + +## 4. 核心正确性:full + other condition + +这是本次最关键的改动。 + +当前问题: + +1. full 在 probe 阶段使用 `MapsAllFull`,命中 key 就先 `setUsed()`。 +2. 后续 `other condition` 再过滤时,可能把这些行过滤掉。 +3. 但 build 侧“已使用”标记已经被置位,probe 后扫描阶段不会再输出这些本应作为“未匹配右行”的记录。 + +需要改成:`used` 只能在 `other condition` 通过后再设置。 + +实现上,下面 1-6 点虽然可以按职责拆成“标记时机 / `handleOtherConditions` 语义 / probe 后扫描联动”三个子问题,但当前代码路径存在强耦合: + +1. 只切 row-flagged map,而不同时补 full 的 `handleOtherConditions` 分支,`full + other condition` 仍可能落到运行时错误分支。 +2. 只切 probe 侧 map,而不同时补 probe 后扫描分支,后扫阶段仍会读取错误的 hash map 类型。 + +因此建议把下面 1-6 点至少作为同一个 review/提交单元交付;文档后面的第 6/7/8 步保留为概念拆分,便于逐项验收,但不建议机械地拆成三个互相独立、可单独合入的 patch。 + +建议实现方向: + +1. 让 full 在 `has_other_condition=true` 时也走 row-flagged map(`MapsAllFullWithRowFlag`)。 +- 受影响点: + - `dbms/src/Interpreters/JoinUtils.h:89` + - `dbms/src/Interpreters/JoinPartition.cpp:281` + - `dbms/src/Interpreters/JoinPartition.cpp:971` + - `dbms/src/Interpreters/JoinPartition.cpp:1037` + +2. `JoinPartition::probeBlock` 增加 full + row_flagged 的 dispatch 分支。 +- 当前 full 固定走 `MapsAllFull`:`dbms/src/Interpreters/JoinPartition.cpp:2124` + +3. row-flagged probe adder 需要支持 full 的“左侧保底输出”语义。 +- 现在 `RowFlaggedHashMapAdder::addNotFound` 不补默认行:`dbms/src/Interpreters/JoinPartition.cpp:1450` +- full 需要 not-found 时仍输出 1 行(右侧默认值)并给 helper 列写可判空值。 + +4. `Join::handleOtherConditions` 增加 full 分支,语义对齐 left outer 的保底行为。 +- 当前只在 `isLeftOuterJoin(kind)` 时做“至少保留一行 + 右侧置 null”逻辑: + - `dbms/src/Interpreters/Join.cpp:999` + - `dbms/src/Interpreters/Join.cpp:1017` + +5. `Join::doJoinBlockHash` 在 full+row_flagged 下: +- 根据 helper 指针给 build row 打 used 标记时要跳过空指针。 +- 结果输出前移除 helper 临时列。 +- 相关位置:`dbms/src/Interpreters/Join.cpp:1325` + +6. probe 后扫描阶段对 full+other_condition 要切到 row_flagged 分支。 +- `dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp:256` + +## 5. 无等值键 full(cartesian full)本轮不做 + +当前代码对 full 无 key 没有成型执行路径,但本轮按 scope 不打通该路径。建议只做明确保护,避免误走: + +1. 若收到 `TypeFullOuterJoin` 且 join keys 为空,TiFlash 显式报 `Unimplemented/BadRequest`,错误信息写清楚“cartesian full 不支持”。 +2. 后续如需支持,再单独立项实现 cross full 路径(工作量明显高于 equal full)。 + +建议在 `getJoinKindAndBuildSideIndex`(或更上游)就明确拦截,避免落到“Unknown join type”这类难排查报错。 + +## 6. Debug/Mock 与测试构造链路 + +为保证 gtest 能构造 full join DAG,需要同步补 mock binder。 + +1. mock schema nullable 规则补 full(左右都 nullable)。 +- `dbms/src/Debug/MockExecutor/JoinBinder.cpp:99` +- `dbms/src/Debug/MockExecutor/JoinBinder.cpp:296` + +2. AST -> tipb 的旧测试编译路径补 full。 +- `dbms/src/Debug/MockExecutor/JoinBinder.cpp:384` + +## 7. 测试改动建议 + +最低覆盖建议: + +1. Coprocessor 映射测试 +- `dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp` +- 新增 full + `inner_idx=0/1`(有 key) +- 明确断言 `build_side_index == inner_idx`(full 场景) +- 可补 1 个无 key 的拒绝用例(验证报错信息),但不实现执行逻辑 + +2. Join Executor correctness +- 参考现有 `RightOuterJoin` 的构造模式:`dbms/src/Flash/tests/gtest_join_executor.cpp:4031` +- 重点新增 full 场景: + - 有 key、无 other condition + - 有 key、有 other condition + - key 命中但 other condition 全失败(必须同时输出 left-unmatched 和 right-unmatched) + - 含 left/right conditions + - 含 null key + +3. Spill / Fine-grained shuffle +- 至少补 1 组 full + other condition 的 spill case +- 现有 join type 数组都还是 7 种: + - `dbms/src/Flash/tests/gtest_join.h:184` + - `dbms/src/Flash/tests/gtest_compute_server.cpp:1248` +- 不建议直接把全量数组扩成 8 再重写大批 expected,可先加 targeted full 用例。 + +## 建议落地顺序 + +1. 同步 tipb(拿到 `TypeFullOuterJoin`)。 +2. 打通 JoinType 映射 + schema nullable + `getJoinTypeName`。 +3. 放开 full 的 left/right condition 校验。 +4. 把 `full + other condition` 的 correctness 主链路一起打通(row-flagged 路径 + `handleOtherConditions` + probe 后扫描联动)。 +5. 增加无等值键 full 的显式拒绝(本轮不实现)。 +6. 补 gtest(先 targeted,再考虑扩大 join type 矩阵)。 + +## 开发步骤(可执行 checklist) + +1. 第 1 步:协议枚举打通(仅 full,且仅等值 key) +- 目标:TiFlash 能识别 `TypeFullOuterJoin`。 +- 修改: + - `contrib/tipb/proto/executor.proto` 增加 `TypeFullOuterJoin`。 + - 生成/同步对应 protobuf 代码(按仓库现有流程)。 +- 验收: + - 编译无 `JoinType` 相关 enum 错误。 + - `TypeFullOuterJoin` 可在 TiFlash 代码中引用。 + +2. 第 2 步:JoinType 映射与 build side 约定 +- 目标:full 映射到 `ASTTableJoin::Kind::Full`,并满足 `build_side_index == inner_idx`。 +- 修改: + - `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp` + - `dbms/src/Flash/Coprocessor/DAGUtils.cpp`(`getJoinTypeName`) + - 若需要,`dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp` +- 验收: + - full + `inner_idx=0/1` 都能返回 `kind=Full` 且 `build_side_index` 与 `inner_idx` 一致。 + +3. 第 3 步:等值 key scope 保护(无 key 显式拒绝) +- 目标:本轮不实现 cartesian full,收到此类请求时报清晰错误。 +- 修改: + - 建议在 `JoinInterpreterHelper::getJoinKindAndBuildSideIndex(...)` 或更上游添加 guard。 +- 验收: + - `TypeFullOuterJoin` 且 `join_keys_size==0` 时,报错信息明确包含“不支持 cartesian full”。 + +4. 第 4 步:full 的输出与 other-condition 输入 schema 全量 nullable +- 目标:full 下左右输出列、other-condition 编译输入列都按 nullable 处理。 +- 修改: + - `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp` + - `dbms/src/Flash/Coprocessor/collectOutputFieldTypes.cpp` + - `dbms/src/Debug/MockExecutor/JoinBinder.cpp`(测试构造链路) +- 验收: + - full 输出 schema 中左右列都是 nullable(不影响 semi 系列)。 + +5. 第 5 步:放开 full 的 left/right conditions 校验 +- 目标:full 可带 left/right condition,不被 `validate` 拦截。 +- 修改: + - `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h`(`JoinNonEqualConditions::validate`) +- 验收: + - full + left_condition、full + right_condition 的构造阶段不报 “non left/right join with ... conditions”。 + +6. 第 6 步:核心改造 A - full + other condition 的 used 标记时机(建议与第 7、8 步合并交付) +- 目标:只有当 `other condition` 通过时,build 行才标记 used,避免漏输出未匹配右行。 +- 说明: + - 当前代码上,row-flagged map 选择、`handleOtherConditions` 的 full 语义、probe 后扫描分支是强耦合的。 + - 如果只落本步、不同时补第 7 和第 8 步,代码可能虽然能部分编译,但 `full + other condition` 仍无法形成可运行、可验证的完整链路。 +- 修改: + - `dbms/src/Interpreters/JoinUtils.h`(让 full+other_condition 进入 row-flagged 路径) + - `dbms/src/Interpreters/JoinPartition.cpp`(map 初始化、probe dispatch、adder not-found 行为) + - `dbms/src/Interpreters/Join.cpp`(`doJoinBlockHash` 标记 used 时跳过空指针) +- 验收: + - 场景:key 命中但 `other condition` 全失败时,右侧行仍能在 probe 后扫描阶段输出。 + +7. 第 7 步:核心改造 B - `handleOtherConditions` full 语义(通常在第 6 步一并落地) +- 目标:full 下也要有“外连接至少保留一行 + 右侧置 null”的正确行为。 +- 修改: + - `dbms/src/Interpreters/Join.cpp`(`handleOtherConditions` 分支) +- 验收: + - full + other_condition 返回结果与“left/right outer 的对称组合语义”一致。 + +8. 第 8 步:probe 后扫描阶段联动(通常在第 6 步一并落地) +- 目标:full+other_condition 使用 row-flagged map 扫描未匹配右行。 +- 修改: + - `dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp` +- 验收: + - full+other_condition 能正确输出 build 侧未匹配行,不丢行、不重复。 + +9. 第 9 步:测试补齐(先 targeted) +- 目标:先保证 correctness,再扩矩阵。 +- 修改建议: + - `dbms/src/Flash/Coprocessor/tests/gtest_join_get_kind_and_build_index.cpp` + - `dbms/src/Flash/tests/gtest_join_executor.cpp` + - `dbms/src/Flash/tests/gtest_spill_join.cpp`(至少 1 组) +- 验收最小集: + - full + key + 无 other_condition + - full + key + 有 other_condition + - full + key 命中但 other_condition 全失败(关键回归) + - full + left/right conditions + - full + null key + +## 一句话风险提示 + +如果只做“JoinType 映射 + 输出 nullable”而不改 row-flagged 逻辑,`FULL OUTER JOIN ... ON eq_key AND other_condition` 会出现右侧漏行,属于结果错误而非性能问题,必须和协议打通一起完成。 diff --git a/docs/note/nulleq_join.md b/docs/note/nulleq_join.md new file mode 100644 index 00000000000..56211d59e0f --- /dev/null +++ b/docs/note/nulleq_join.md @@ -0,0 +1,601 @@ +# TiFlash NullEQ Join Key(`<=>` / `tidbNullEQ`)设计文档 + +## 背景 + +TiFlash 当前 Hash Join 的默认等值语义是: + +- `NULL` 不参与等值匹配 +- build 侧含 `NULL` key 的行不会进入 hash map +- probe 侧含 `NULL` key 的行会被直接当作 not matched + +这与 Null-safe equal(`<=>` / `tidbNullEQ`)要求的语义不同: + +- `NULL <=> NULL` 为 `true` +- `NULL <=> non-NULL` 为 `false` +- `non-NULL <=> non-NULL` 与普通 `=` 一致 + +本设计文档讨论的是:在 TiFlash 已经支持 `FULL OUTER JOIN` 的前提下,如何为 hash join 增加 **join key 粒度的 NullEQ 语义**。 + +## 目标 + +1. TiFlash 的 Hash Join 在 **join key** 使用 NullEQ 语义时结果正确。 +2. 支持 **key 粒度混合语义**: + - 同一个 join 中允许部分 key 走 `=` + - 允许部分 key 走 `<=>` +3. 未下发 NullEQ 标记时,保持现有行为不变。 +4. 与已经支持的 `FULL OUTER JOIN` 语义兼容,不引入 `NULL <=> NULL` 相关的 outer join 错误结果。 + +## 非目标 + +1. 不把 `other_conditions` 中的 `<=>` 纳入本文新增的 join-key NullEQ 语义范围;若 planner 下发了这类表达式,按普通 other condition 处理。 +2. MVP 不追求最优性能,允许为了正确性强制走 `serialized`。 +3. 不扩大 NullAware join(`NOT IN` 家族)的语义覆盖范围;MVP 阶段建议 fail-fast。 +4. 不在本轮实现 cartesian full join 与 NullEQ 的组合。 + +## 作用范围 + +本轮 scope 限定为: + +- hash join +- `left_join_keys/right_join_keys` 非空 +- NullEQ 只出现在 join key 上 + +不包含: + +- cartesian join +- 仅通过 `other_conditions` 表达 NullEQ join 语义的计划形态 +- planner 把 `<=>` 重写成其它表达式后再由 TiFlash 反推语义 + +## 输入契约 + +### tipb 协议 + +建议在 `tipb::Join` 中增加: + +- `repeated bool is_null_eq = ...;` + +语义如下: + +- `is_null_eq[i] = false`:第 `i` 对 join key 使用普通 `=` +- `is_null_eq[i] = true`:第 `i` 对 join key 使用 `<=>` + +长度约束: + +- `is_null_eq_size == 0`:视为全 `false`,兼容旧版本 +- 否则必须满足: + - `is_null_eq_size == left_join_keys_size` + - `is_null_eq_size == right_join_keys_size` + +### Join key 表达形式 + +MVP 假设 join key 由 planner 下发为列引用: + +- `left_join_keys[i]` 与 `right_join_keys[i]` 是一一对齐的 key pair +- TiFlash 如需在执行层插入 cast 做类型对齐,不改变 key 的顺序和数量 +- `is_null_eq[i]` 始终按 key pair index 对齐,而不是按 build/probe 角色对齐 + +### 语义边界 + +NullEQ 语义只通过 `is_null_eq[]` 表达: + +- 若 `<=>` 出现在 `other_conditions` 中,则按普通布尔表达式处理;本文不要求执行层从 `other_conditions` 里的 `<=>` 反推出“这是 join key NullEQ” +- 不让执行层从通用表达式里反推某个 key 是否是 NullEQ + +### 与 NullAware join 的关系 + +`is_null_aware_semi_join` 与 NullEQ 是两套不同语义: + +- NullAware join 关注 `NOT IN` 的三值逻辑 +- NullEQ 关注 join key 的比较语义 + +MVP 建议: + +- 若 `is_null_aware_semi_join=true` 且存在任意 `is_null_eq[i]=true`,直接 fail-fast + +原因是这两条路径都对“NULL key 行怎么处理”有强假设,混用很容易产生 silent wrong result。 + +## 当前实现的关键假设 + +当前 Join 框架里,与 NullEQ 直接冲突的假设主要有四类。 + +### 1. key-NULL 会被提前过滤 + +当前 build/probe 都会把 nullable key 做两件事: + +1. 把 `ColumnNullable` 替换成 nested column +2. 把 key 中的 `NULL` 行写入 `null_map` + +对应路径: + +- build:`Join::insertFromBlockInternal()` +- probe:`ProbeProcessInfo::prepareForHashProbe()` + +这意味着: + +- build 侧 `NULL` key 行默认不入 map +- probe 侧 `NULL` key 行默认不 probe map + +这与 NullEQ 的 `NULL <=> NULL` 可以匹配直接冲突。 + +### 2. side-condition 与 key-NULL 共用一张 null_map + +当前 `recordFilteredRows()` 会复用同一张 `null_map`,把 side-condition 过滤结果与 “key 是否为 NULL” 混到一起。 + +对 NullEQ 来说,问题不在于执行链路必须长期维护两张独立的 map,而在于: + +- 普通 `=` key 的 `NULL` 过滤 +- left/right side-condition 过滤 + +这两类来源在**生成过滤结果**时必须区分,因为: + +- 对普通 `=` key,应把 key 的 `NULL` 写入最终过滤结果 +- 对 NullEQ key,不应把 key 的 `NULL` 写入最终过滤结果 +- side-condition 的过滤结果则始终需要写入最终过滤结果 + +因此更准确的做法是: + +- 先按 key 粒度决定哪些 `NULL` 需要参与过滤 +- 再与 side-condition 的过滤结果合并成一张统一的 `row_filter_map` + +也就是说,最终可以只有一张“这一行是否跳过 insert/probe”的 map,但不能继续沿用当前这种“先无差别把所有 key-NULL 都写进 null_map,再复用它叠加 side-condition”的实现方式。 + +### 3. RowsNotInsertToMap / scan-after-probe 默认把 NULL key 当作天然 unmatched + +对于 right/full/right semi/right anti/null-aware 这些需要保留 build 侧特殊行的 join kind,当前实现会把“未入 map 的 build 行”记进 `RowsNotInsertToMap`,之后在 scan-after-probe 阶段输出。 + +在普通 `=` 语义下这成立,因为 key-NULL 本来就不参与匹配。 + +但在 NullEQ 语义下: + +- `NULL` key 行不一定是 unmatched +- 它可能应该入 map,并与 probe 侧 `NULL` key 行成功匹配 + +### 4. KeyGetter 默认不编码 nullable bitmap + +当前 `keys128/keys256` 这类 fixed key hash method 默认是 `has_nullable_keys = false`。 + +这意味着即使不提前过滤 `NULL`,现有 packed key 路径也未必能正确把 nullness 编进 hash key。 + +## 与 FULL OUTER JOIN 的额外交互 + +NullEQ 本身不是 `FULL OUTER JOIN` 专属问题,`LEFT OUTER JOIN` 和 `RIGHT OUTER JOIN` 也会受影响。 +但在 TiFlash 已经支持 `FULL OUTER JOIN` 之后,有几件事必须在设计里显式纳入,否则很容易出现双边都错的结果。 + +### 1. “NULL key 走天然 unmatched 路径”不是 full 特有问题,但 full 会把问题放大 + +这件事对不同 outer join 的影响不同: + +- `LEFT OUTER JOIN` + - probe 侧 `NULL` key 若仍直接走 `addNotFound()`,本该命中的 `NULL <=> NULL` 会被错误输出成左 unmatched +- `RIGHT OUTER JOIN` + - build 侧 `NULL` key 若仍进 `RowsNotInsertToMap`,本该命中的行会在 scan-after-probe 阶段被错误输出成右 unmatched +- `FULL OUTER JOIN` + - 上述两条路径会同时存在 + - 一组本该匹配的 `NULL <=> NULL` 行,可能被错误拆成: + - 一条左 unmatched + - 一条右 unmatched + +所以这不是 full 独有问题,但 full 会把问题表现得最明显、也最复杂。 + +### 2. FULL + other condition 必须继续沿用“延后 setUsed”语义 + +当前 full 分支已经为 `full + other condition` 做了专门修正: + +- key 命中时不能立刻 `setUsed()` +- 必须等 `other condition` 真正通过后,再标记 build 行为 used + +否则 probe 后扫描阶段会漏输出本应作为 unmatched build 行的记录。 + +在 NullEQ 引入后,这个约束仍然成立,而且要覆盖 `NULL <=> NULL` 命中的情况: + +- key 通过是因为 `NULL <=> NULL` +- other condition 失败 +- 正确语义应该是: + - 左侧保留一条右补 null 的 unmatched 行 + - 右侧 build 行仍然在后扫阶段输出为 unmatched + +因此 NullEQ 不能绕开 full 分支当前的 row-flagged / delayed-used 设计。 + +### 3. RowsNotInsertToMap 在 full 下要重新定义语义 + +在支持 full 之后,`RowsNotInsertToMap` 不能再简单理解成“所有 NULL key 行 + 所有 build condition 失败的行”。 + +更准确的语义应该是: + +- build side-condition 失败的行 +- 普通 `=` key 因 key-NULL 被过滤的行 + +不应包含: + +- NullEQ key 为 `NULL` 的行 + +因为这些行应该入 map,并可能成功匹配。 + +### 4. dispatch hash / spill / fine-grained shuffle 在 full 下更容易暴露错误 + +如果 build/probe 的 dispatch hash 没有把 nullness 编进 key: + +- build 侧 `NULL` key 与 probe 侧 `NULL` key 可能落到不同 partition +- inner join 下通常表现为“不命中” +- full join 下则可能进一步演变成: + - probe 侧输出一条 unmatched + - build 侧后扫再输出一条 unmatched + +所以 `full + NullEQ + spill/FGS` 应该是 MVP 测试矩阵里的必测项,而不是后续补充项。 + +### 5. full 的 schema nullable 规则不需要为 NullEQ 再单独扩展 + +这一点反而不用新增复杂度: + +- full 输出 schema 两边本来就都应为 nullable +- other-condition 输入 schema 两边也已经按 full 语义处理成 nullable + +NullEQ 改变的是 **匹配语义**,不是 full 输出 schema 的 nullable 规则。 + +## 设计选择 + +## 1. 总体原则 + +NullEQ 设计遵循两个核心原则: + +1. 把“key 是否为 NULL”与“row 是否因 side-condition 被过滤”分离 +2. 让 NullEQ 的 `NULL` 真正进入 key 比较,而不是继续被当成特殊 unmatched 行 + +由此得到两个概念: + +- `row_filter_map` + - 表示这一行不需要 insert/probe + - 原因可以是 left/right condition 失败,也可以是普通 `=` key 的 `NULL` +- `key_null_map` + - 只对普通 `=` key 有意义 + - NullEQ key 不应把 `NULL` 写进这张 map + +这里保留这两个名字,主要是为了说明“过滤结果的来源”。 + +最终实现里,它们完全可以合并成一张统一的 `row_filter_map`: + +- 普通 `=` key 的 `NULL` 可以进入这张 map +- left/right side-condition 的过滤结果也进入这张 map +- 但 NullEQ key 的 `NULL` 不能进入这张 map + +换句话说,关键不是最终一定要维护两张独立的 map,而是生成最终过滤结果时,必须按 key 粒度决定哪些 `NULL` 应该被当作“跳过 insert/probe”的条件。 + +## 2. build/probe 都按 key 粒度区分 `=` 与 `<=>` + +对于每个 key pair: + +- 若 `is_null_eq[i] = false` + - 延续现有 `=` 语义 + - key 中有 `NULL` 时,这一行不参与匹配 +- 若 `is_null_eq[i] = true` + - 保留 nullable key + - `NULL` 可以参与 hash / probe / match + +也就是说,NullEQ 不是“整条 join 全都变 null-safe”,而是按 key pair 生效。 + +## 3. build 路径设计 + +build 阶段的目标是: + +- NullEQ key 为 `NULL` 的行可以入 map +- 普通 `=` key 为 `NULL` 的行仍不入 map +- side-condition 失败的行不入 map,但 outer join 语义所需的保底输出仍要保留 + +建议做法: + +1. 从原始 key columns 出发,不再无条件对所有 key 调用 `extractNestedColumnsAndNullMap` +2. 遍历每个 key: + - 对 `=` key: + - 若是 nullable,则取 nested column + - 并把该列 null map OR 进 `row_filter_map` + - 对 `<=>` key: + - 保留 `ColumnNullable` + - 不把该列 null map 写进 `row_filter_map` +3. 再把 build side-condition 的过滤结果 OR 进 `row_filter_map` +4. 传给 `JoinPartition::insertBlockIntoMaps(..., row_filter_map, ...)` + +这样 build 路径上的语义就变成: + +- `row_filter_map[i] = 1` + - 这一行不入 map +- `row_filter_map[i] = 0` + - 这一行入 map + +对于 full/right outer/right semi/right anti 这些会记录 build 特殊行的 join kind: + +- 只有 side-condition 失败的行、或者普通 `=` key 的 `NULL` 行,才进入 `RowsNotInsertToMap` +- NullEQ key 的 `NULL` 行不应进入 `RowsNotInsertToMap` + +## 4. probe 路径设计 + +probe 阶段的目标是: + +- NullEQ key 为 `NULL` 的行可以真正 probe map +- 普通 `=` key 的 `NULL` 行仍然按“不匹配”处理 +- left/full outer 语义下,probe unmatched 的保底输出仍然正确 + +建议做法与 build 对称: + +1. 不再无条件对所有 key 做 `extractNestedColumnsAndNullMap` +2. 逐 key 处理: + - `=` key 的 `NULL` 写入 `row_filter_map` + - `<=>` key 保留 nullable,不写入 `row_filter_map` +3. 再把 probe side-condition 的过滤结果 OR 进 `row_filter_map` +4. probe 时: + - `row_filter_map[i] = 1` 的行继续走历史 unmatched 路径 + - `row_filter_map[i] = 0` 的行真正进入 hash probe + +这条规则对 outer join 的影响是: + +- `LEFT OUTER JOIN` 不会再把 NullEQ 的 `NULL` probe 行过早打成 unmatched +- `FULL OUTER JOIN` 同理,但还要与后扫 build unmatched 语义一起对齐 + +## 5. Hash key 编码策略 + +### 当前实现 + +当存在 **nullable 的 NullEQ key** 时,当前 Join map method 的选择规则是: + +- 若参与编码的 key columns 都是 fixed-size,且 `null bitmap + payload` 能放进 `UInt128/UInt256` + - 分别走 `JoinMapMethod::nullable_keys128` / `JoinMapMethod::nullable_keys256` +- 其它情况继续回退到 `JoinMapMethod::serialized` + +fixed-size 路径复用了 HashAgg / Set 已有的 nullable packed keys 思路: + +- `keys128/keys256 + has_nullable_keys = true` +- 把 nullness bitmap 与 key payload 一起编码进 packed key + +这样常见的 nullable numeric / datetime NullEQ join 不必再一律退化到 `serialized`。 + +`serialized` 仍然保留为正确性兜底: + +- 变长 key 仍可自然保留 `ColumnNullable` 的 nullness +- fixed-size key 若带 bitmap 后放不进 `UInt256`,仍可继续工作 + +但这里有一个必须显式满足的前提: + +- `serialized` 只是在“当前列对象长什么样”这个层面保留 nullness +- 它不会自动把 `Nullable(T)` 与 `T` 归一成同一种物理编码 + +当前 `ColumnNullable::serializeValueIntoArena()` 会先写入 null flag,再写 nested value。 +因此对于同一个非空值: + +- `Nullable(Int32)` 的序列化结果 +- `Int32` 的序列化结果 + +并不相同。 + +这意味着: + +- 若某个 NullEQ key pair 一侧是 nullable、另一侧是 non-nullable +- 即使 build/probe 两边都走 `serialized` +- 只要两边最终 key schema 仍分别是 `Nullable(T)` 与 `T` +- 相同的非空值也可能 hash / probe 不命中 + +因此 MVP 不能只做“nullable NullEQ 强制 serialized”,还必须保证: + +- 对每个 `is_null_eq[i] = true` 的 key pair +- 只要任一侧最终需要保留 nullable 语义 +- build/probe 两边就必须在 prepare key 阶段对齐到同一个物理 key schema +- 最直接的做法是统一到 `Nullable(common_type)` + +这一步最初属于 `serialized` 正确性兜底的一部分;在 fixed-size packed key 优化落地后,这个 schema 对齐约束仍然需要继续保持。 + +## 6. JoinPartition / KeyGetter 语义 + +NullEQ 真正落地到 JoinPartition 时,关键不是“有没有 nullable 列”,而是: + +- key getter 能不能把 nullness 编进 key + +当前 JoinPartition 已显式引入 nullable-aware 的 fixed-key KeyGetter 分支: + +- `nullable_keys128 -> HashMethodKeysFixed<..., UInt128, ..., true, false>` +- `nullable_keys256 -> HashMethodKeysFixed<..., UInt256, ..., true, false>` + +对应语义是: + +- packed key 路径会把 nullness bitmap 编进 key +- 变长 key 或超出 `UInt256` 的 fixed-size key 仍走 `serialized` + +无论走哪条路径,仍要保证: + +- build/probe 传进来的 key columns 保留了 NullEQ key 的 nullable 信息 +- 对任意 NullEQ key pair,build/probe 两边最终参与编码的 key schema 一致 +- 尤其是 mixed nullable / non-nullable 的场景,不能保留成 `Nullable(T)` 对 `T` + +## 7. FULL + other condition 语义 + +由于 full 分支已经有 row-flagged 逻辑,NullEQ 这里的要求不是新增一套 full 语义,而是确保 NullEQ key 命中也走已有正确链路: + +1. 对 `full + other condition`: + - 继续使用 row-flagged map +2. 对 key 命中但 other condition 失败的场景: + - build 行的 used 标记必须延后到 other condition 通过后 +3. 这个规则必须覆盖: + - 普通值命中 + - `NULL <=> NULL` 命中 + +否则 full 下会出现漏右行或重复 unmatched 行。 + +## 8. RuntimeFilter + +MVP 建议: + +- 只要 join 含 NullEQ,且存在 nullable 的 NullEQ key,就禁用 runtime filter + +原因: + +- 当前 runtime filter / Set 路径仍然默认丢弃 `NULL` key +- 这与 NullEQ 的 “NULL 可以匹配” 冲突 + +更长期的方向可以是: + +- Set 里额外维护 `has_null` +- 单列 NullEQ key 的 runtime filter 应用语义改成: + - `isNull(x) ? has_null : (x IN set)` + +但这不建议放进 MVP。 + +## 9. 备选方案:Planner 重写 `<=>` + +另一条路线是让 TiDB planner 不显式下发 `is_null_eq[]`,而是把每个 `<=>` key 重写成: + +1. `isNull(k)` +2. `ifNull(k, sentinel)` + +这样 TiFlash 仍然走普通 `=` join。 + +这条路线的优点是执行层改动小,但缺点也很明显: + +- 每个 `<=>` key 变成两个 key +- hash key 变宽 +- planner/runtime filter/cast/collation 都会更绕 +- key 级别语义变得不够直观 + +因此本设计默认选择: + +- TiFlash 原生支持 key 粒度 NullEQ + +## 测试建议 + +MVP 至少应覆盖: + +1. `INNER JOIN` + - `NULL <=> NULL` 命中 + - `NULL <=> 1` 不命中 +2. `LEFT OUTER JOIN` + - probe 侧 `NULL` key 不会被过早当作 unmatched +3. `RIGHT OUTER JOIN` + - build 侧 `NULL` key 不会被错误塞进 `RowsNotInsertToMap` +4. `FULL OUTER JOIN` + - `NULL <=> NULL` 命中时,不会被拆成两条 unmatched + - `NULL <=> NULL` 命中但 `other condition` 失败时,左右 unmatched 都正确 +5. `SEMI / ANTI` + - `NULL <=> NULL` 参与存在性判断 +6. 多列混合语义 + - `k1 <=> k1 AND k2 = k2` + - `k1 <=> k1 AND k2 <=> k2` +7. side-condition 交互 + - left/right condition 与 NullEQ key 共存 +8. spill / fine-grained shuffle + - 特别是 `FULL OUTER JOIN + NullEQ` + +### CP3 测试补充进度 + +按当前 workspace 的进度,CP3 的 spill / FGS 链路已补齐,当前已覆盖: + +1. `spill + FULL OUTER JOIN + NullEQ` + - `NULL <=> NULL` 命中后不会被拆成两条 unmatched +2. `spill + FULL OUTER JOIN + NullEQ + other condition` + - 数据同时覆盖 `other condition = false/true` + - `other condition = false` 时,build 行仍会在 scan-after-probe 正确输出 + - `other condition = true` 时,build 行会被正常消费,不会再次输出 +3. `fine-grained shuffle + NullEQ` + - 覆盖了一组 nullable key + - 验证 build / probe 两侧 key schema 对齐后,probe 不会把 NullEQ 的 `NULL` 误判成 filtered / unmatched + +## 代码热点 + +- `dbms/src/Flash/Coprocessor/JoinInterpreterHelper.*` +- `dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp` +- `dbms/src/Interpreters/Join.h` +- `dbms/src/Interpreters/Join.cpp` +- `dbms/src/Interpreters/ProbeProcessInfo.cpp` +- `dbms/src/Interpreters/JoinPartition.cpp` +- `dbms/src/Interpreters/JoinHashMap.cpp` +- `dbms/src/DataStreams/ScanHashMapAfterProbeBlockInputStream.cpp` +- `dbms/src/Interpreters/Set.cpp` + +--- + +## 开发追踪 / Dev Note + +这一节放在文档后半部分,用于后续按 checkpoint 推进时记录实现进度。设计结论以前面的章节为准。 + +### How to continue + +继续开发前建议固定做三件事: + +1. 先读本文件的设计部分 +2. 再跑 `git status` / `git diff --stat` +3. 明确本次只推进哪个 checkpoint + +建议在后续指令里直接写: + +- “以 `docs/note/nulleq_join.md` 为准,从 CP2 开始继续” +- “先读设计文档,再读当前进度” + +### Milestone 划分 + +#### Milestone 0:协议 / Plumbing + +Done 标准: + +- TiFlash 能解析 `is_null_eq[]` +- 能透传到 `DB::Join` +- 未下发该字段时行为零变化 + +#### Milestone 1:正确性 MVP + +Done 标准: + +- nullable NullEQ key 能正确 build / probe +- mixed nullable / non-nullable 的 NullEQ key pair 能正确对齐 key schema 并命中 +- outer join / scan-after-probe 不把 NullEQ 的 `NULL` 行误判为 unmatched +- `FULL OUTER JOIN + other condition` 与 NullEQ 组合语义正确 +- runtime filter 在该模式下被禁用 + +#### Milestone 2:测试矩阵 + +Done 标准: + +- inner / left / right / full / semi / anti 的基础矩阵覆盖齐 +- mixed key、side-condition、spill/FGS 覆盖齐 + +#### Milestone 3:性能优化 + +Done 标准: + +- nullable fixed-size key 不再强制 serialized + +#### Milestone 4:RuntimeFilter(可选) + +Done 标准: + +- 单列 NullEQ key 的 runtime filter 语义正确,或明确长期禁用 + +### Checkpoint 建议 + +- CP0:tipb 字段 + TiFlash 解析 +- CP1:`DB::Join` 保存/打印 `is_null_eq` +- CP2.1:nullable NullEQ 强制 serialized + mixed-nullability key schema 对齐 + NullAware 互斥检查 +- CP2.2:build/probe 的 row_filter_map 语义拆分 +- CP2.3:`RowsNotInsertToMap` / scan-after-probe 调整 +- CP2.4:`FULL OUTER JOIN + other condition` 与 NullEQ 联动自测 +- CP2.5:MVP 禁用 runtime filter +- CP3:补测试 +- CP4:packed keys 优化 + +### 当前进度 + +- 说明:以下勾选按当前 workspace 核对,用于记录本轮开发推进状态。 +- [x] tipb: `Join.is_null_eq` 字段定义 +- [x] TiFlash: `JoinInterpreterHelper::TiFlashJoin` 解析 `is_null_eq[]` +- [x] TiFlash: `DB::Join` 保存/打印 `is_null_eq` +- [x] TiFlash: nullable NullEQ 强制 serialized + mixed-nullability key schema 对齐 + NullAware 互斥 fail-fast +- [x] TiFlash: build/probe 的 row_filter_map 语义拆分 +- [x] TiFlash: `RowsNotInsertToMap` / scan-after-probe 调整 +- [x] TiFlash: `FULL OUTER JOIN + other condition` 与 NullEQ 联动验证 +- [x] TiFlash: runtime filter 禁用 +- [x] TiFlash: gtest 已覆盖 inner / left / right / full / semi / anti 基础矩阵 +- [x] TiFlash: gtest 已覆盖 mixed key 与 side-condition 交互 +- [x] TiFlash: spill / fine-grained shuffle 测试覆盖 +- [x] TiFlash: packed keys 优化(nullable fixed-size NullEQ key 可走 `nullable_keys128/256`,其余场景回退 `serialized`) + +### Open Questions + +- TiDB / kvproto 何时同步 `is_null_eq[]` +- key 若未来允许表达式,`is_null_eq[i]` 如何稳定对齐 +- string + collation 的性能回退是否可接受 +- spill / FGS 场景下是否需要单独的 profile 或 debug 指标 +- NullAware join 是否永远与 NullEQ 互斥,还是未来要定义组合语义 diff --git a/tests/fullstack-test/expr/json_object.test b/tests/fullstack-test/expr/json_object.test new file mode 100644 index 00000000000..a7c1092b0be --- /dev/null +++ b/tests/fullstack-test/expr/json_object.test @@ -0,0 +1,74 @@ +# Copyright 2023 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.t_json_object; +mysql> create table test.t_json_object(id int, k1 varchar(20), k2 varchar(20), v_int int, v_str varchar(20), v_json json); +mysql> alter table test.t_json_object set tiflash replica 1; +mysql> insert into test.t_json_object values + (1, 'b', 'a', 1, 'x', '{\"nested\":1}'), + (2, 'dup', 'dup', 2, 'last', '[1,2]'), + (3, 'c', 'b', null, null, '[]'), + (4, null, 'a', 4, 'boom', '{}'); #NO_UNESCAPE + +func> wait_table test t_json_object + +# TODO: re-enable this explain check after the TiDB-side explain change is merged. +# mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; desc format='brief' select id, json_object(k1, v_int, k2, v_str) as res from test.t_json_object where id in (1, 2) order by id; +# {#REGEXP}id\s+estRows\s+task\s+access object\s+operator info +# {#REGEXP}Sort\s+\d+\.\d+\s+root\s+.*test\.t_json_object\.id +# {#REGEXP}└─TableReader\s+\d+\.\d+\s+root\s+.*MppVersion: 2, data:ExchangeSender +# {#REGEXP}.*ExchangeSender\s+\d+\.\d+\s+mpp\[tiflash\]\s+.*ExchangeType: PassThrough +# {#REGEXP}.*Projection\s+\d+\.\d+\s+mpp\[tiflash\]\s+.*json_object\(test\.t_json_object\.k1, cast\(test\.t_json_object\.v_int, json BINARY\), test\.t_json_object\.k2, cast\(test\.t_json_object\.v_str, json BINARY\)\)->Column#\d+ +# {#REGEXP}.*TableFullScan\s+\d+\.\d+\s+mpp\[tiflash\]\s+table:t_json_object\s+pushed down filter:in\(test\.t_json_object\.id, 1, 2\), keep order:false, stats:pseudo + +# empty object +mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select json_object() from test.t_json_object where id = 1; ++---------------+ +| json_object() | ++---------------+ +| {} | ++---------------+ + +# mixed value types and key sorting +mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select id, json_object('b', v_int, 'a', v_str, 'c', v_json) as res from test.t_json_object where id in (1, 3) order by id; ++----+--------------------------------------+ +| id | res | ++----+--------------------------------------+ +| 1 | {"a": "x", "b": 1, "c": {"nested": 1}} | +| 3 | {"a": null, "b": null, "c": []} | ++----+--------------------------------------+ + +# dynamic key columns and duplicate keys +mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select id, json_object(k1, v_int, k2, v_str) as res from test.t_json_object where id in (1, 2) order by id; ++----+--------------------+ +| id | res | ++----+--------------------+ +| 1 | {"a": "x", "b": 1} | +| 2 | {"dup": "last"} | ++----+--------------------+ + +# SQL NULL value becomes JSON null +mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select json_object('obj', v_json, 'nil', null) as res from test.t_json_object where id = 1; ++-----------------------------------+ +| res | ++-----------------------------------+ +| {"nil": null, "obj": {"nested": 1}} | ++-----------------------------------+ + +# NULL key should fail at execution time +mysql> set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select json_object(k1, v_int, k2, v_str) from test.t_json_object where id = 4; +{#REGEXP}.*NULL member names.* + +# Clean up. +mysql> drop table if exists test.t_json_object;