44 const Executor* executor) {
47 geo_func_finder.
visit(qual);
49 auto const inner_table_key = (*table_key_pair).inner_table_key;
50 auto const outer_table_key = (*table_key_pair).outer_table_key;
53 CHECK_NE(inner_table_key, outer_table_key);
55 auto const inner_table_cardinality =
57 auto const outer_table_cardinality =
59 auto inner_qual_decision = inner_table_cardinality > outer_table_cardinality
68 const auto inner_cv_it =
69 std::find_if(geo_args.begin(),
72 return cv->getTableKey() == inner_table_key;
74 CHECK(inner_cv_it != geo_args.end());
75 const auto inner_cv = *inner_cv_it;
76 bool needs_table_reordering = inner_table_cardinality < outer_table_cardinality;
77 const auto outer_inner_card_ratio =
78 outer_table_cardinality /
static_cast<double>(inner_table_cardinality);
80 target_geo_func_name) ||
82 target_geo_func_name)) {
89 if (inner_cv->get_rte_idx() == 0 &&
90 (inner_cv->get_type_info().get_type() ==
kPOINT)) {
92 if (needs_table_reordering && outer_inner_card_ratio > 10.0 &&
93 inner_table_cardinality < 10000) {
107 if (needs_table_reordering) {
123 VLOG(2) <<
"Detect geo join operator, initial_inner_table(db_id: "
124 << inner_table_key.db_id <<
", table_id: " << inner_table_key.table_id
125 <<
"), cardinality: " << inner_table_cardinality
126 <<
"), initial_outer_table(db_id: " << outer_table_key.db_id
127 <<
", table_id: " << outer_table_key.table_id
128 <<
"), cardinality: " << outer_table_cardinality
129 <<
"), inner_qual_decision: " << inner_qual_decision;
130 return {200, 200, inner_qual_decision};
136 std::vector<SQLTypes> geo_types_for_func;
137 for (
size_t i = 0; i < func_oper->getArity(); i++) {
138 const auto arg_expr = func_oper->
getArg(i);
141 geo_types_for_func.push_back(ti.get_type());
144 std::regex geo_func_regex(
"ST_[\\w]*");
145 std::smatch geo_func_match;
146 const auto& func_name = func_oper->getName();
147 if (geo_types_for_func.size() == 2 &&
148 std::regex_match(func_name, geo_func_match, geo_func_regex)) {
165 const auto normalized_bin_oper =
167 const auto& inner_outer = normalized_bin_oper.first;
169 auto lhs = bin_oper->get_left_operand();
170 if (
auto lhs_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(
171 bin_oper->get_left_operand())) {
172 lhs = lhs_tuple->getTuple().front().get();
175 if (lhs == inner_outer.front().first) {
177 }
else if (lhs == inner_outer.front().second) {
183 return {200, 200, inner_qual_decision};
186 return {100, 100, inner_qual_decision};
192 const std::vector<InputTableInfo>& table_infos,
193 const Executor* executor,
194 std::vector<std::map<node_t, InnerQualDecision>>& qual_detection_res) {
195 CHECK_EQ(left_deep_join_quals.size() + 1, table_infos.size());
196 std::vector<std::map<node_t, cost_t>> join_cost_graph(table_infos.size());
200 for (
const auto& current_level_join_conditions : left_deep_join_quals) {
201 for (
const auto& qual : current_level_join_conditions.quals) {
202 std::set<int> qual_nest_levels = visitor.
visit(qual.get());
203 if (qual_nest_levels.size() != 2) {
206 int lhs_nest_level = *qual_nest_levels.begin();
208 qual_nest_levels.erase(qual_nest_levels.begin());
209 int rhs_nest_level = *qual_nest_levels.begin();
214 qual_detection_res[lhs_nest_level][rhs_nest_level] = std::get<2>(qual_costing);
215 qual_detection_res[rhs_nest_level][lhs_nest_level] = std::get<2>(qual_costing);
216 const auto edge_it = join_cost_graph[lhs_nest_level].find(rhs_nest_level);
217 auto rhs_cost = std::get<1>(qual_costing);
218 if (edge_it == join_cost_graph[lhs_nest_level].end() ||
219 edge_it->second > rhs_cost) {
220 auto lhs_cost = std::get<0>(qual_costing);
221 join_cost_graph[lhs_nest_level][rhs_nest_level] = rhs_cost;
222 join_cost_graph[rhs_nest_level][lhs_nest_level] = lhs_cost;
226 return join_cost_graph;
239 for (
auto& inbound_for_node : inbound_) {
240 inbound_for_node.erase(from);
246 std::unordered_set<node_t> roots;
247 for (
node_t candidate = 0; candidate < inbound_.size(); ++candidate) {
248 if (inbound_[candidate].empty()) {
249 roots.insert(candidate);
268 const std::vector<std::map<node_t, cost_t>>& join_cost_graph) {
274 for (
size_t level_idx = 0; level_idx < left_deep_join_quals.size(); ++level_idx) {
276 dependency_tracking.
addEdge(level_idx, level_idx + 1);
279 return dependency_tracking;
286 const std::vector<std::map<node_t, cost_t>>& join_cost_graph,
287 const std::vector<InputTableInfo>& table_infos,
288 const std::function<
bool(
const node_t lhs_nest_level,
const node_t rhs_nest_level)>&
292 std::vector<std::map<node_t, InnerQualDecision>>& qual_normalization_res) {
293 std::vector<node_t> all_nest_levels(table_infos.size());
294 std::iota(all_nest_levels.begin(), all_nest_levels.end(), 0);
295 std::vector<node_t> input_permutation;
296 std::unordered_set<node_t> visited;
297 auto dependency_tracking =
299 auto schedulable_node = [&dependency_tracking, &visited](
const node_t node) {
300 const auto nodes_ready = dependency_tracking.getRoots();
301 return nodes_ready.find(node) != nodes_ready.end() &&
302 visited.find(node) == visited.end();
304 while (visited.size() < table_infos.size()) {
306 std::vector<node_t> remaining_nest_levels;
307 std::copy_if(all_nest_levels.begin(),
308 all_nest_levels.end(),
309 std::back_inserter(remaining_nest_levels),
311 CHECK(!remaining_nest_levels.empty());
313 const auto start_it = std::max_element(
314 remaining_nest_levels.begin(), remaining_nest_levels.end(), compare_node);
315 CHECK(start_it != remaining_nest_levels.end());
316 std::priority_queue<TraversalEdge, std::vector<TraversalEdge>, decltype(compare_edge)>
317 worklist(compare_edge);
328 if (remaining_nest_levels.size() == 2 && qual_normalization_res[start].size() == 1) {
329 auto inner_qual_decision = qual_normalization_res[start].begin()->second;
330 auto join_qual = left_deep_join_quals.begin()->quals;
333 bool (*)(
const Analyzer::ColumnVar*,
const Analyzer::ColumnVar*)>;
335 auto set_new_rte_idx = [](ColvarSet& cv_set,
int new_rte) {
337 cv_set.begin(), cv_set.end(), [new_rte](
const Analyzer::ColumnVar* cv) {
338 const_cast<Analyzer::ColumnVar*
>(cv)->set_rte_idx(new_rte);
348 auto analyze_join_qual = [&start,
349 &remaining_nest_levels,
350 &inner_qual_decision,
352 compare_node](
const std::shared_ptr<Analyzer::Expr>& lhs,
353 ColvarSet& lhs_colvar_set,
354 const std::shared_ptr<Analyzer::Expr>& rhs,
355 ColvarSet& rhs_colvar_set) {
356 if (!lhs || !rhs || lhs_colvar_set.empty() || rhs_colvar_set.empty()) {
357 return std::make_pair(Decision::IGNORE, start);
360 auto alternative_it = std::find_if(
361 remaining_nest_levels.begin(),
362 remaining_nest_levels.end(),
363 [start](
const size_t nest_level) {
return start != nest_level; });
364 CHECK(alternative_it != remaining_nest_levels.end());
365 auto alternative_rte = *alternative_it;
367 Decision decision = Decision::IGNORE;
371 bool is_outer_col_valid =
false;
372 auto check_expr_is_valid_col = [&is_outer_col_valid](
const Analyzer::Expr* expr) {
373 if (
auto expr_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(expr)) {
374 for (
auto& inner_expr : expr_tuple->getTuple()) {
376 HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(inner_expr.get());
378 is_outer_col_valid =
false;
383 auto cv_from_expr = HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(expr);
385 is_outer_col_valid =
false;
389 is_outer_col_valid =
true;
392 inner_rte = (*lhs_colvar_set.begin())->get_rte_idx();
393 outer_rte = (*rhs_colvar_set.begin())->get_rte_idx();
394 check_expr_is_valid_col(rhs.get());
396 inner_rte = (*rhs_colvar_set.begin())->get_rte_idx();
397 outer_rte = (*lhs_colvar_set.begin())->get_rte_idx();
398 check_expr_is_valid_col(lhs.get());
400 if (inner_rte >= 0 && outer_rte >= 0) {
401 const auto inner_cardinality =
402 table_infos[inner_rte].info.getNumTuplesUpperBound();
403 const auto outer_cardinality =
404 table_infos[outer_rte].info.getNumTuplesUpperBound();
406 if (inner_rte == static_cast<int>(start)) {
411 decision = is_outer_col_valid && inner_cardinality > outer_cardinality
415 CHECK_EQ(inner_rte, static_cast<int>(alternative_rte));
417 if (compare_node(inner_rte, start)) {
420 decision = Decision::IGNORE;
424 decision = Decision::KEEP;
430 if (decision == Decision::KEEP) {
431 return std::make_pair(decision, start);
432 }
else if (decision == Decision::SWAP) {
433 return std::make_pair(decision, alternative_rte);
435 return std::make_pair(Decision::IGNORE, start);
438 auto collect_colvars = [](
const std::shared_ptr<Analyzer::Expr> expr,
440 expr->collect_column_var(cv_set,
false);
443 auto adjust_reordering_logic = [&start, &start_edge, &start_it, set_new_rte_idx](
446 ColvarSet& lhs_colvar_set,
447 ColvarSet& rhs_colvar_set) {
448 CHECK(decision == Decision::SWAP);
449 start = alternative_rte;
450 set_new_rte_idx(lhs_colvar_set, alternative_rte);
451 set_new_rte_idx(rhs_colvar_set, *start_it);
452 start_edge.join_cost = 0;
453 start_edge.nest_level = start;
459 auto rhs = bin_op->get_own_right_operand();
460 if (
auto lhs_exp = dynamic_cast<Analyzer::ExpressionTuple*>(lhs.get())) {
465 auto& lhs_exprs = lhs_exp->getTuple();
466 auto& rhs_exprs = rhs_exp->getTuple();
467 CHECK_EQ(lhs_exprs.size(), rhs_exprs.size());
468 for (
size_t i = 0; i < lhs_exprs.size(); ++i) {
469 Decision decision{Decision::IGNORE};
470 int alternative_rte_idx = -1;
473 collect_colvars(lhs_exprs.at(i), lhs_colvar_set);
474 collect_colvars(rhs_exprs.at(i), rhs_colvar_set);
476 auto investigation_res =
477 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
478 decision = investigation_res.first;
479 if (decision == Decision::KEEP) {
480 return remaining_nest_levels;
482 alternative_rte_idx = investigation_res.second;
484 if (decision == Decision::SWAP) {
485 adjust_reordering_logic(
486 decision, alternative_rte_idx, lhs_colvar_set, rhs_colvar_set);
492 collect_colvars(lhs, lhs_colvar_set);
493 collect_colvars(rhs, rhs_colvar_set);
494 auto investigation_res =
495 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
496 if (investigation_res.first == Decision::KEEP) {
497 return remaining_nest_levels;
498 }
else if (investigation_res.first == Decision::SWAP) {
499 adjust_reordering_logic(investigation_res.first,
500 investigation_res.second,
508 VLOG(2) <<
"Table reordering starting with nest level " << start;
509 for (
const auto& graph_edge : join_cost_graph[*start_it]) {
510 const node_t succ = graph_edge.first;
511 if (!schedulable_node(succ)) {
515 for (
const auto& successor_edge : join_cost_graph[succ]) {
516 if (successor_edge.first == start) {
517 start_edge.
join_cost = successor_edge.second;
520 if (compare_edge(start_edge, succ_edge)) {
521 VLOG(2) <<
"Table reordering changing start nest level from " << start
524 start_edge = succ_edge;
529 VLOG(2) <<
"Table reordering picked start nest level " << start <<
" with cost "
530 << start_edge.join_cost;
531 CHECK_EQ(start, start_edge.nest_level);
532 worklist.push(start_edge);
533 const auto it_ok = visited.insert(start);
535 while (!worklist.empty()) {
539 VLOG(1) <<
"Insert input permutation, idx: " << input_permutation.size()
542 dependency_tracking.removeNode(crt.
nest_level);
544 for (
const auto& graph_edge : join_cost_graph[crt.
nest_level]) {
545 const node_t succ = graph_edge.first;
546 if (!schedulable_node(succ)) {
550 const auto it_ok = visited.insert(succ);
555 return input_permutation;
562 const std::vector<InputTableInfo>& table_infos,
563 const Executor* executor) {
564 std::vector<std::map<node_t, InnerQualDecision>> qual_normalization_res(
567 left_deep_join_quals, table_infos, executor, qual_normalization_res);
569 const auto compare_node = [&table_infos](
const node_t lhs_nest_level,
570 const node_t rhs_nest_level) {
571 return table_infos[lhs_nest_level].info.getNumTuplesUpperBound() <
572 table_infos[rhs_nest_level].info.getNumTuplesUpperBound();
574 const auto compare_edge = [&compare_node](
const TraversalEdge& lhs_edge,
575 const TraversalEdge& rhs_edge) {
577 if (lhs_edge.join_cost == rhs_edge.join_cost) {
578 return compare_node(lhs_edge.nest_level, rhs_edge.nest_level);
580 return lhs_edge.join_cost > rhs_edge.join_cost;
586 left_deep_join_quals,
587 qual_normalization_res);
static bool is_point_poly_rewrite_target_func(std::string_view target_func_name)
std::unordered_set< node_t > getRoots() const
static bool colvar_comp(const ColumnVar *l, const ColumnVar *r)
#define IS_EQUIVALENCE(X)
SchedulingDependencyTracking build_dependency_tracking(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< std::map< node_t, cost_t >> &join_cost_graph)
bool is_constructed_point(const Analyzer::Expr *expr)
static std::unordered_map< SQLTypes, cost_t > std::tuple< cost_t, cost_t, InnerQualDecision > get_join_qual_cost(const Analyzer::Expr *qual, const Executor *executor)
size_t get_table_cardinality(shared::TableKey const &table_key, Executor const *executor)
const std::vector< const Analyzer::ColumnVar * > & getGeoArgCvs() const
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
unsigned g_trivial_loop_join_threshold
std::vector< node_t > get_node_input_permutation(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor)
static std::unordered_map< SQLTypes, cost_t > GEO_TYPE_COSTS
void removeNode(const node_t from)
std::vector< std::map< node_t, cost_t > > build_join_cost_graph(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor, std::vector< std::map< node_t, InnerQualDecision >> &qual_detection_res)
const std::string & getGeoFunctionName() const
const SQLTypeInfo & get_type_info() const
SchedulingDependencyTracking(const size_t node_count)
const Analyzer::Expr * getArg(const size_t i) const
DEVICE void iota(ARGS &&...args)
void addEdge(const node_t from, const node_t to)
const std::optional< GeoJoinOperandsTableKeyPair > getJoinTableKeyPair() const
std::vector< std::unordered_set< node_t > > inbound_
static std::pair< std::vector< InnerOuter >, std::vector< InnerOuterStringOpInfos > > normalizeColumnPairs(const Analyzer::BinOper *condition, const TemporaryTables *temporary_tables)
static bool is_poly_point_rewrite_target_func(std::string_view target_func_name)
std::vector< node_t > traverse_join_cost_graph(const std::vector< std::map< node_t, cost_t >> &join_cost_graph, const std::vector< InputTableInfo > &table_infos, const std::function< bool(const node_t lhs_nest_level, const node_t rhs_nest_level)> &compare_node, const std::function< bool(const TraversalEdge &, const TraversalEdge &)> &compare_edge, const JoinQualsPerNestingLevel &left_deep_join_quals, std::vector< std::map< node_t, InnerQualDecision >> &qual_normalization_res)
AccessManager::Decision Decision
InnerQualDecision inner_qual_decision
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const