47 const Executor* executor) {
50 geo_func_finder.
visit(qual);
54 if (inner_table_key.table_id != -1 && outer_table_key.table_id != -1) {
57 CHECK_NE(outer_table_key, inner_table_key);
58 const auto inner_table_metadata =
60 const auto outer_table_metadata =
63 if (inner_table_metadata->fragmenter && outer_table_metadata->fragmenter) {
64 const auto inner_table_cardinality =
65 inner_table_metadata->fragmenter->getNumRows();
66 const auto outer_table_cardinality =
67 outer_table_metadata->fragmenter->getNumRows();
68 auto inner_qual_decision = inner_table_cardinality > outer_table_cardinality
77 const auto inner_cv_it =
78 std::find_if(geo_args.begin(),
81 return cv->getTableKey() == inner_table_key;
83 CHECK(inner_cv_it != geo_args.end());
84 const auto outer_cv_it =
85 std::find_if(geo_args.begin(),
88 return cv->getTableKey() == outer_table_key;
90 CHECK(outer_cv_it != geo_args.end());
91 const auto inner_cv = *inner_cv_it;
92 bool needs_table_reordering = inner_table_cardinality < outer_table_cardinality;
93 const auto outer_inner_card_ratio =
94 outer_table_cardinality /
static_cast<double>(inner_table_cardinality);
96 target_geo_func_name) ||
98 target_geo_func_name)) {
104 if (inner_cv->get_rte_idx() == 0 &&
105 (inner_cv->get_type_info().get_type() ==
kPOINT)) {
107 if (needs_table_reordering && outer_inner_card_ratio > 10.0 &&
108 inner_table_cardinality < 10000) {
122 if (needs_table_reordering) {
138 VLOG(2) <<
"Detect geo join operator, initial_inner_table(db_id: "
139 << inner_table_key.db_id <<
", table_id: " << inner_table_key.table_id
140 <<
"), cardinality: " << inner_table_cardinality
141 <<
"), initial_outer_table(db_id: " << outer_table_key.db_id
142 <<
", table_id: " << outer_table_key.table_id
143 <<
"), cardinality: " << outer_table_cardinality
144 <<
"), inner_qual_decision: " << inner_qual_decision;
145 return {200, 200, inner_qual_decision};
151 std::vector<SQLTypes> geo_types_for_func;
152 for (
size_t i = 0; i < func_oper->getArity(); i++) {
153 const auto arg_expr = func_oper->
getArg(i);
156 geo_types_for_func.push_back(ti.get_type());
159 std::regex geo_func_regex(
"ST_[\\w]*");
160 std::smatch geo_func_match;
161 const auto& func_name = func_oper->getName();
162 if (geo_types_for_func.size() == 2 &&
163 std::regex_match(func_name, geo_func_match, geo_func_regex)) {
180 const auto normalized_bin_oper =
182 const auto& inner_outer = normalized_bin_oper.first;
184 auto lhs = bin_oper->get_left_operand();
185 if (
auto lhs_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(
186 bin_oper->get_left_operand())) {
187 lhs = lhs_tuple->getTuple().front().get();
190 if (lhs == inner_outer.front().first) {
192 }
else if (lhs == inner_outer.front().second) {
198 return {200, 200, inner_qual_decision};
201 return {100, 100, inner_qual_decision};
207 const std::vector<InputTableInfo>& table_infos,
208 const Executor* executor,
209 std::vector<std::map<node_t, InnerQualDecision>>& qual_detection_res) {
210 CHECK_EQ(left_deep_join_quals.size() + 1, table_infos.size());
211 std::vector<std::map<node_t, cost_t>> join_cost_graph(table_infos.size());
215 for (
const auto& current_level_join_conditions : left_deep_join_quals) {
216 for (
const auto& qual : current_level_join_conditions.quals) {
217 std::set<int> qual_nest_levels = visitor.
visit(qual.get());
218 if (qual_nest_levels.size() != 2) {
221 int lhs_nest_level = *qual_nest_levels.begin();
223 qual_nest_levels.erase(qual_nest_levels.begin());
224 int rhs_nest_level = *qual_nest_levels.begin();
229 qual_detection_res[lhs_nest_level][rhs_nest_level] = std::get<2>(qual_costing);
230 qual_detection_res[rhs_nest_level][lhs_nest_level] = std::get<2>(qual_costing);
231 const auto edge_it = join_cost_graph[lhs_nest_level].find(rhs_nest_level);
232 auto rhs_cost = std::get<1>(qual_costing);
233 if (edge_it == join_cost_graph[lhs_nest_level].end() ||
234 edge_it->second > rhs_cost) {
235 auto lhs_cost = std::get<0>(qual_costing);
236 join_cost_graph[lhs_nest_level][rhs_nest_level] = rhs_cost;
237 join_cost_graph[rhs_nest_level][lhs_nest_level] = lhs_cost;
241 return join_cost_graph;
254 for (
auto& inbound_for_node : inbound_) {
255 inbound_for_node.erase(from);
261 std::unordered_set<node_t> roots;
262 for (
node_t candidate = 0; candidate < inbound_.size(); ++candidate) {
263 if (inbound_[candidate].empty()) {
264 roots.insert(candidate);
283 const std::vector<std::map<node_t, cost_t>>& join_cost_graph) {
289 for (
size_t level_idx = 0; level_idx < left_deep_join_quals.size(); ++level_idx) {
291 dependency_tracking.
addEdge(level_idx, level_idx + 1);
294 return dependency_tracking;
301 const std::vector<std::map<node_t, cost_t>>& join_cost_graph,
302 const std::vector<InputTableInfo>& table_infos,
303 const std::function<
bool(
const node_t lhs_nest_level,
const node_t rhs_nest_level)>&
307 std::vector<std::map<node_t, InnerQualDecision>>& qual_normalization_res) {
308 std::vector<node_t> all_nest_levels(table_infos.size());
309 std::iota(all_nest_levels.begin(), all_nest_levels.end(), 0);
310 std::vector<node_t> input_permutation;
311 std::unordered_set<node_t> visited;
312 auto dependency_tracking =
314 auto schedulable_node = [&dependency_tracking, &visited](
const node_t node) {
315 const auto nodes_ready = dependency_tracking.getRoots();
316 return nodes_ready.find(node) != nodes_ready.end() &&
317 visited.find(node) == visited.end();
319 while (visited.size() < table_infos.size()) {
321 std::vector<node_t> remaining_nest_levels;
322 std::copy_if(all_nest_levels.begin(),
323 all_nest_levels.end(),
324 std::back_inserter(remaining_nest_levels),
326 CHECK(!remaining_nest_levels.empty());
328 const auto start_it = std::max_element(
329 remaining_nest_levels.begin(), remaining_nest_levels.end(), compare_node);
330 CHECK(start_it != remaining_nest_levels.end());
331 std::priority_queue<TraversalEdge, std::vector<TraversalEdge>, decltype(compare_edge)>
332 worklist(compare_edge);
343 if (remaining_nest_levels.size() == 2 && qual_normalization_res[start].size() == 1) {
344 auto inner_qual_decision = qual_normalization_res[start].begin()->second;
345 auto join_qual = left_deep_join_quals.begin()->quals;
348 bool (*)(
const Analyzer::ColumnVar*,
const Analyzer::ColumnVar*)>;
350 auto set_new_rte_idx = [](ColvarSet& cv_set,
int new_rte) {
352 cv_set.begin(), cv_set.end(), [new_rte](
const Analyzer::ColumnVar* cv) {
353 const_cast<Analyzer::ColumnVar*
>(cv)->set_rte_idx(new_rte);
363 auto analyze_join_qual = [&start,
364 &remaining_nest_levels,
365 &inner_qual_decision,
367 compare_node](
const std::shared_ptr<Analyzer::Expr>& lhs,
368 ColvarSet& lhs_colvar_set,
369 const std::shared_ptr<Analyzer::Expr>& rhs,
370 ColvarSet& rhs_colvar_set) {
371 if (!lhs || !rhs || lhs_colvar_set.empty() || rhs_colvar_set.empty()) {
372 return std::make_pair(Decision::IGNORE, start);
375 auto alternative_it = std::find_if(
376 remaining_nest_levels.begin(),
377 remaining_nest_levels.end(),
378 [start](
const size_t nest_level) {
return start != nest_level; });
379 CHECK(alternative_it != remaining_nest_levels.end());
380 auto alternative_rte = *alternative_it;
382 Decision decision = Decision::IGNORE;
386 bool is_outer_col_valid =
false;
387 auto check_expr_is_valid_col = [&is_outer_col_valid](
const Analyzer::Expr* expr) {
388 if (
auto expr_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(expr)) {
389 for (
auto& inner_expr : expr_tuple->getTuple()) {
391 HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(inner_expr.get());
393 is_outer_col_valid =
false;
398 auto cv_from_expr = HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(expr);
400 is_outer_col_valid =
false;
404 is_outer_col_valid =
true;
407 inner_rte = (*lhs_colvar_set.begin())->get_rte_idx();
408 outer_rte = (*rhs_colvar_set.begin())->get_rte_idx();
409 check_expr_is_valid_col(rhs.get());
411 inner_rte = (*rhs_colvar_set.begin())->get_rte_idx();
412 outer_rte = (*lhs_colvar_set.begin())->get_rte_idx();
413 check_expr_is_valid_col(lhs.get());
415 if (inner_rte >= 0 && outer_rte >= 0) {
416 const auto inner_cardinality =
417 table_infos[inner_rte].info.getNumTuplesUpperBound();
418 const auto outer_cardinality =
419 table_infos[outer_rte].info.getNumTuplesUpperBound();
421 if (inner_rte == static_cast<int>(start)) {
426 decision = is_outer_col_valid && inner_cardinality > outer_cardinality
430 CHECK_EQ(inner_rte, static_cast<int>(alternative_rte));
432 if (compare_node(inner_rte, start)) {
435 decision = Decision::IGNORE;
439 decision = Decision::KEEP;
445 if (decision == Decision::KEEP) {
446 return std::make_pair(decision, start);
447 }
else if (decision == Decision::SWAP) {
448 return std::make_pair(decision, alternative_rte);
450 return std::make_pair(Decision::IGNORE, start);
453 auto collect_colvars = [](
const std::shared_ptr<Analyzer::Expr> expr,
455 expr->collect_column_var(cv_set,
false);
458 auto adjust_reordering_logic = [&start, &start_edge, &start_it, set_new_rte_idx](
461 ColvarSet& lhs_colvar_set,
462 ColvarSet& rhs_colvar_set) {
463 CHECK(decision == Decision::SWAP);
464 start = alternative_rte;
465 set_new_rte_idx(lhs_colvar_set, alternative_rte);
466 set_new_rte_idx(rhs_colvar_set, *start_it);
467 start_edge.join_cost = 0;
468 start_edge.nest_level = start;
474 auto rhs = bin_op->get_own_right_operand();
475 if (
auto lhs_exp = dynamic_cast<Analyzer::ExpressionTuple*>(lhs.get())) {
480 auto& lhs_exprs = lhs_exp->getTuple();
481 auto& rhs_exprs = rhs_exp->getTuple();
482 CHECK_EQ(lhs_exprs.size(), rhs_exprs.size());
483 for (
size_t i = 0; i < lhs_exprs.size(); ++i) {
484 Decision decision{Decision::IGNORE};
485 int alternative_rte_idx = -1;
488 collect_colvars(lhs_exprs.at(i), lhs_colvar_set);
489 collect_colvars(rhs_exprs.at(i), rhs_colvar_set);
491 auto investigation_res =
492 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
493 decision = investigation_res.first;
494 if (decision == Decision::KEEP) {
495 return remaining_nest_levels;
497 alternative_rte_idx = investigation_res.second;
499 if (decision == Decision::SWAP) {
500 adjust_reordering_logic(
501 decision, alternative_rte_idx, lhs_colvar_set, rhs_colvar_set);
507 collect_colvars(lhs, lhs_colvar_set);
508 collect_colvars(rhs, rhs_colvar_set);
509 auto investigation_res =
510 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
511 if (investigation_res.first == Decision::KEEP) {
512 return remaining_nest_levels;
513 }
else if (investigation_res.first == Decision::SWAP) {
514 adjust_reordering_logic(investigation_res.first,
515 investigation_res.second,
523 VLOG(2) <<
"Table reordering starting with nest level " << start;
524 for (
const auto& graph_edge : join_cost_graph[*start_it]) {
525 const node_t succ = graph_edge.first;
526 if (!schedulable_node(succ)) {
530 for (
const auto& successor_edge : join_cost_graph[succ]) {
531 if (successor_edge.first == start) {
532 start_edge.
join_cost = successor_edge.second;
535 if (compare_edge(start_edge, succ_edge)) {
536 VLOG(2) <<
"Table reordering changing start nest level from " << start
539 start_edge = succ_edge;
544 VLOG(2) <<
"Table reordering picked start nest level " << start <<
" with cost "
545 << start_edge.join_cost;
546 CHECK_EQ(start, start_edge.nest_level);
547 worklist.push(start_edge);
548 const auto it_ok = visited.insert(start);
550 while (!worklist.empty()) {
554 VLOG(1) <<
"Insert input permutation, idx: " << input_permutation.size()
557 dependency_tracking.removeNode(crt.
nest_level);
559 for (
const auto& graph_edge : join_cost_graph[crt.
nest_level]) {
560 const node_t succ = graph_edge.first;
561 if (!schedulable_node(succ)) {
565 const auto it_ok = visited.insert(succ);
570 return input_permutation;
577 const std::vector<InputTableInfo>& table_infos,
578 const Executor* executor) {
579 std::vector<std::map<node_t, InnerQualDecision>> qual_normalization_res(
582 left_deep_join_quals, table_infos, executor, qual_normalization_res);
584 const auto compare_node = [&table_infos](
const node_t lhs_nest_level,
585 const node_t rhs_nest_level) {
586 return table_infos[lhs_nest_level].info.getNumTuplesUpperBound() <
587 table_infos[rhs_nest_level].info.getNumTuplesUpperBound();
589 const auto compare_edge = [&compare_node](
const TraversalEdge& lhs_edge,
590 const TraversalEdge& rhs_edge) {
592 if (lhs_edge.join_cost == rhs_edge.join_cost) {
593 return compare_node(lhs_edge.nest_level, rhs_edge.nest_level);
595 return lhs_edge.join_cost > rhs_edge.join_cost;
601 left_deep_join_quals,
602 qual_normalization_res);
std::unordered_set< node_t > getRoots() const
static bool colvar_comp(const ColumnVar *l, const ColumnVar *r)
#define IS_EQUIVALENCE(X)
SchedulingDependencyTracking build_dependency_tracking(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< std::map< node_t, cost_t >> &join_cost_graph)
bool is_constructed_point(const Analyzer::Expr *expr)
static std::unordered_map< SQLTypes, cost_t > std::tuple< cost_t, cost_t, InnerQualDecision > get_join_qual_cost(const Analyzer::Expr *qual, const Executor *executor)
const TableDescriptor * get_metadata_for_table(const ::shared::TableKey &table_key, bool populate_fragmenter)
const std::vector< const Analyzer::ColumnVar * > & getGeoArgCvs() const
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
unsigned g_trivial_loop_join_threshold
std::vector< node_t > get_node_input_permutation(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor)
static std::unordered_map< SQLTypes, cost_t > GEO_TYPE_COSTS
void removeNode(const node_t from)
std::vector< std::map< node_t, cost_t > > build_join_cost_graph(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor, std::vector< std::map< node_t, InnerQualDecision >> &qual_detection_res)
const std::string & getGeoFunctionName() const
static bool is_poly_point_rewrite_target_func(std::string_view target_func_name)
const SQLTypeInfo & get_type_info() const
SchedulingDependencyTracking(const size_t node_count)
const Analyzer::Expr * getArg(const size_t i) const
DEVICE void iota(ARGS &&...args)
void addEdge(const node_t from, const node_t to)
std::vector< std::unordered_set< node_t > > inbound_
static std::pair< std::vector< InnerOuter >, std::vector< InnerOuterStringOpInfos > > normalizeColumnPairs(const Analyzer::BinOper *condition, const TemporaryTables *temporary_tables)
std::vector< node_t > traverse_join_cost_graph(const std::vector< std::map< node_t, cost_t >> &join_cost_graph, const std::vector< InputTableInfo > &table_infos, const std::function< bool(const node_t lhs_nest_level, const node_t rhs_nest_level)> &compare_node, const std::function< bool(const TraversalEdge &, const TraversalEdge &)> &compare_edge, const JoinQualsPerNestingLevel &left_deep_join_quals, std::vector< std::map< node_t, InnerQualDecision >> &qual_normalization_res)
AccessManager::Decision Decision
InnerQualDecision inner_qual_decision
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
static bool is_point_poly_rewrite_target_func(std::string_view target_func_name)
const std::pair< shared::TableKey, shared::TableKey > getTableIdsOfGeoExpr() const