47 const Executor* executor) {
50 geo_func_finder.
visit(qual);
51 int inner_table_id = -1;
52 int outer_table_id = -1;
54 if (inner_table_id != -1 && outer_table_id != -1) {
57 CHECK_NE(outer_table_id, inner_table_id);
58 const auto cat = executor->getCatalog();
59 const auto inner_table_metadata =
cat->getMetadataForTable(inner_table_id);
60 const auto outer_table_metadata =
cat->getMetadataForTable(outer_table_id);
62 if (inner_table_metadata->fragmenter && outer_table_metadata->fragmenter) {
63 const auto inner_table_cardinality =
64 inner_table_metadata->fragmenter->getNumRows();
65 const auto outer_table_cardinality =
66 outer_table_metadata->fragmenter->getNumRows();
67 auto inner_qual_decision = inner_table_cardinality > outer_table_cardinality
76 const auto inner_cv_it =
77 std::find_if(geo_args.begin(),
80 return cv->get_table_id() == inner_table_id;
82 CHECK(inner_cv_it != geo_args.end());
83 const auto outer_cv_it =
84 std::find_if(geo_args.begin(),
87 return cv->get_table_id() == outer_table_id;
89 CHECK(outer_cv_it != geo_args.end());
90 const auto inner_cv = *inner_cv_it;
91 bool needs_table_reordering = inner_table_cardinality < outer_table_cardinality;
92 const auto outer_inner_card_ratio =
93 outer_table_cardinality /
static_cast<double>(inner_table_cardinality);
95 target_geo_func_name) ||
97 target_geo_func_name)) {
103 if (inner_cv->get_rte_idx() == 0 &&
104 (inner_cv->get_type_info().get_type() ==
kPOINT)) {
106 if (needs_table_reordering && outer_inner_card_ratio > 10.0 &&
107 inner_table_cardinality < 10000) {
121 if (needs_table_reordering) {
137 VLOG(2) <<
"Detect geo join operator, initial_inner_table(table_id: "
138 << inner_table_id <<
", cardinality: " << inner_table_cardinality
139 <<
"), initial_outer_table(table_id: " << outer_table_id
140 <<
", cardinality: " << outer_table_cardinality
141 <<
"), inner_qual_decision: " << inner_qual_decision;
142 return {200, 200, inner_qual_decision};
148 std::vector<SQLTypes> geo_types_for_func;
149 for (
size_t i = 0; i < func_oper->getArity(); i++) {
150 const auto arg_expr = func_oper->
getArg(i);
153 geo_types_for_func.push_back(ti.get_type());
156 std::regex geo_func_regex(
"ST_[\\w]*");
157 std::smatch geo_func_match;
158 const auto& func_name = func_oper->getName();
159 if (geo_types_for_func.size() == 2 &&
160 std::regex_match(func_name, geo_func_match, geo_func_regex)) {
178 bin_oper, *executor->getCatalog(), executor->getTemporaryTables());
179 const auto& inner_outer = normalized_bin_oper.first;
181 auto lhs = bin_oper->get_left_operand();
182 if (
auto lhs_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(
183 bin_oper->get_left_operand())) {
184 lhs = lhs_tuple->getTuple().front().get();
187 if (lhs == inner_outer.front().first) {
189 }
else if (lhs == inner_outer.front().second) {
195 return {200, 200, inner_qual_decision};
198 return {100, 100, inner_qual_decision};
204 const std::vector<InputTableInfo>& table_infos,
205 const Executor* executor,
206 std::vector<std::map<node_t, InnerQualDecision>>& qual_detection_res) {
207 CHECK_EQ(left_deep_join_quals.size() + 1, table_infos.size());
208 std::vector<std::map<node_t, cost_t>> join_cost_graph(table_infos.size());
212 for (
const auto& current_level_join_conditions : left_deep_join_quals) {
213 for (
const auto& qual : current_level_join_conditions.quals) {
214 std::set<int> qual_nest_levels = visitor.
visit(qual.get());
215 if (qual_nest_levels.size() != 2) {
218 int lhs_nest_level = *qual_nest_levels.begin();
220 qual_nest_levels.erase(qual_nest_levels.begin());
221 int rhs_nest_level = *qual_nest_levels.begin();
226 qual_detection_res[lhs_nest_level][rhs_nest_level] = std::get<2>(qual_costing);
227 qual_detection_res[rhs_nest_level][lhs_nest_level] = std::get<2>(qual_costing);
228 const auto edge_it = join_cost_graph[lhs_nest_level].find(rhs_nest_level);
229 auto rhs_cost = std::get<1>(qual_costing);
230 if (edge_it == join_cost_graph[lhs_nest_level].end() ||
231 edge_it->second > rhs_cost) {
232 auto lhs_cost = std::get<0>(qual_costing);
233 join_cost_graph[lhs_nest_level][rhs_nest_level] = rhs_cost;
234 join_cost_graph[rhs_nest_level][lhs_nest_level] = lhs_cost;
238 return join_cost_graph;
251 for (
auto& inbound_for_node : inbound_) {
252 inbound_for_node.erase(from);
258 std::unordered_set<node_t> roots;
259 for (
node_t candidate = 0; candidate < inbound_.size(); ++candidate) {
260 if (inbound_[candidate].empty()) {
261 roots.insert(candidate);
280 const std::vector<std::map<node_t, cost_t>>& join_cost_graph) {
286 for (
size_t level_idx = 0; level_idx < left_deep_join_quals.size(); ++level_idx) {
288 dependency_tracking.
addEdge(level_idx, level_idx + 1);
291 return dependency_tracking;
298 const std::vector<std::map<node_t, cost_t>>& join_cost_graph,
299 const std::vector<InputTableInfo>& table_infos,
300 const std::function<
bool(
const node_t lhs_nest_level,
const node_t rhs_nest_level)>&
304 std::vector<std::map<node_t, InnerQualDecision>>& qual_normalization_res) {
305 std::vector<node_t> all_nest_levels(table_infos.size());
306 std::iota(all_nest_levels.begin(), all_nest_levels.end(), 0);
307 std::vector<node_t> input_permutation;
308 std::unordered_set<node_t> visited;
309 auto dependency_tracking =
311 auto schedulable_node = [&dependency_tracking, &visited](
const node_t node) {
312 const auto nodes_ready = dependency_tracking.getRoots();
313 return nodes_ready.find(node) != nodes_ready.end() &&
314 visited.find(node) == visited.end();
316 while (visited.size() < table_infos.size()) {
318 std::vector<node_t> remaining_nest_levels;
319 std::copy_if(all_nest_levels.begin(),
320 all_nest_levels.end(),
321 std::back_inserter(remaining_nest_levels),
323 CHECK(!remaining_nest_levels.empty());
325 const auto start_it = std::max_element(
326 remaining_nest_levels.begin(), remaining_nest_levels.end(), compare_node);
327 CHECK(start_it != remaining_nest_levels.end());
328 std::priority_queue<TraversalEdge, std::vector<TraversalEdge>, decltype(compare_edge)>
329 worklist(compare_edge);
340 if (remaining_nest_levels.size() == 2 && qual_normalization_res[start].size() == 1) {
341 auto inner_qual_decision = qual_normalization_res[start].begin()->second;
342 auto join_qual = left_deep_join_quals.begin()->quals;
345 bool (*)(
const Analyzer::ColumnVar*,
const Analyzer::ColumnVar*)>;
347 auto set_new_rte_idx = [](ColvarSet& cv_set,
int new_rte) {
349 cv_set.begin(), cv_set.end(), [new_rte](
const Analyzer::ColumnVar* cv) {
350 const_cast<Analyzer::ColumnVar*
>(cv)->set_rte_idx(new_rte);
360 auto analyze_join_qual = [&start,
361 &remaining_nest_levels,
362 &inner_qual_decision,
364 compare_node](
const std::shared_ptr<Analyzer::Expr>& lhs,
365 ColvarSet& lhs_colvar_set,
366 const std::shared_ptr<Analyzer::Expr>& rhs,
367 ColvarSet& rhs_colvar_set) {
368 if (!lhs || !rhs || lhs_colvar_set.empty() || rhs_colvar_set.empty()) {
369 return std::make_pair(Decision::IGNORE, start);
372 auto alternative_it = std::find_if(
373 remaining_nest_levels.begin(),
374 remaining_nest_levels.end(),
375 [start](
const size_t nest_level) {
return start != nest_level; });
376 CHECK(alternative_it != remaining_nest_levels.end());
377 auto alternative_rte = *alternative_it;
379 Decision decision = Decision::IGNORE;
383 bool is_outer_col_valid =
false;
384 auto check_expr_is_valid_col = [&is_outer_col_valid](
const Analyzer::Expr* expr) {
385 if (
auto expr_tuple = dynamic_cast<const Analyzer::ExpressionTuple*>(expr)) {
386 for (
auto& inner_expr : expr_tuple->getTuple()) {
388 HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(inner_expr.get());
390 is_outer_col_valid =
false;
395 auto cv_from_expr = HashJoin::getHashJoinColumn<Analyzer::ColumnVar>(expr);
397 is_outer_col_valid =
false;
401 is_outer_col_valid =
true;
404 inner_rte = (*lhs_colvar_set.begin())->get_rte_idx();
405 outer_rte = (*rhs_colvar_set.begin())->get_rte_idx();
406 check_expr_is_valid_col(rhs.get());
408 inner_rte = (*rhs_colvar_set.begin())->get_rte_idx();
409 outer_rte = (*lhs_colvar_set.begin())->get_rte_idx();
410 check_expr_is_valid_col(lhs.get());
412 if (inner_rte >= 0 && outer_rte >= 0) {
413 const auto inner_cardinality =
414 table_infos[inner_rte].info.getNumTuplesUpperBound();
415 const auto outer_cardinality =
416 table_infos[outer_rte].info.getNumTuplesUpperBound();
418 if (inner_rte == static_cast<int>(start)) {
423 decision = is_outer_col_valid && inner_cardinality > outer_cardinality
427 CHECK_EQ(inner_rte, static_cast<int>(alternative_rte));
429 if (compare_node(inner_rte, start)) {
432 decision = Decision::IGNORE;
436 decision = Decision::KEEP;
442 if (decision == Decision::KEEP) {
443 return std::make_pair(decision, start);
444 }
else if (decision == Decision::SWAP) {
445 return std::make_pair(decision, alternative_rte);
447 return std::make_pair(Decision::IGNORE, start);
450 auto collect_colvars = [](
const std::shared_ptr<Analyzer::Expr> expr,
452 expr->collect_column_var(cv_set,
false);
455 auto adjust_reordering_logic = [&start, &start_edge, &start_it, set_new_rte_idx](
458 ColvarSet& lhs_colvar_set,
459 ColvarSet& rhs_colvar_set) {
460 CHECK(decision == Decision::SWAP);
461 start = alternative_rte;
462 set_new_rte_idx(lhs_colvar_set, alternative_rte);
463 set_new_rte_idx(rhs_colvar_set, *start_it);
464 start_edge.join_cost = 0;
465 start_edge.nest_level = start;
471 auto rhs = bin_op->get_own_right_operand();
472 if (
auto lhs_exp = dynamic_cast<Analyzer::ExpressionTuple*>(lhs.get())) {
477 auto& lhs_exprs = lhs_exp->getTuple();
478 auto& rhs_exprs = rhs_exp->getTuple();
479 CHECK_EQ(lhs_exprs.size(), rhs_exprs.size());
480 for (
size_t i = 0; i < lhs_exprs.size(); ++i) {
481 Decision decision{Decision::IGNORE};
482 int alternative_rte_idx = -1;
485 collect_colvars(lhs_exprs.at(i), lhs_colvar_set);
486 collect_colvars(rhs_exprs.at(i), rhs_colvar_set);
488 auto investigation_res =
489 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
490 decision = investigation_res.first;
491 if (decision == Decision::KEEP) {
492 return remaining_nest_levels;
494 alternative_rte_idx = investigation_res.second;
496 if (decision == Decision::SWAP) {
497 adjust_reordering_logic(
498 decision, alternative_rte_idx, lhs_colvar_set, rhs_colvar_set);
504 collect_colvars(lhs, lhs_colvar_set);
505 collect_colvars(rhs, rhs_colvar_set);
506 auto investigation_res =
507 analyze_join_qual(lhs, lhs_colvar_set, rhs, rhs_colvar_set);
508 if (investigation_res.first == Decision::KEEP) {
509 return remaining_nest_levels;
510 }
else if (investigation_res.first == Decision::SWAP) {
511 adjust_reordering_logic(investigation_res.first,
512 investigation_res.second,
520 VLOG(2) <<
"Table reordering starting with nest level " << start;
521 for (
const auto& graph_edge : join_cost_graph[*start_it]) {
522 const node_t succ = graph_edge.first;
523 if (!schedulable_node(succ)) {
527 for (
const auto& successor_edge : join_cost_graph[succ]) {
528 if (successor_edge.first == start) {
529 start_edge.
join_cost = successor_edge.second;
532 if (compare_edge(start_edge, succ_edge)) {
533 VLOG(2) <<
"Table reordering changing start nest level from " << start
536 start_edge = succ_edge;
541 VLOG(2) <<
"Table reordering picked start nest level " << start <<
" with cost "
542 << start_edge.join_cost;
543 CHECK_EQ(start, start_edge.nest_level);
544 worklist.push(start_edge);
545 const auto it_ok = visited.insert(start);
547 while (!worklist.empty()) {
551 VLOG(1) <<
"Insert input permutation, idx: " << input_permutation.size()
554 dependency_tracking.removeNode(crt.
nest_level);
556 for (
const auto& graph_edge : join_cost_graph[crt.
nest_level]) {
557 const node_t succ = graph_edge.first;
558 if (!schedulable_node(succ)) {
562 const auto it_ok = visited.insert(succ);
567 return input_permutation;
574 const std::vector<InputTableInfo>& table_infos,
575 const Executor* executor) {
576 std::vector<std::map<node_t, InnerQualDecision>> qual_normalization_res(
579 left_deep_join_quals, table_infos, executor, qual_normalization_res);
581 const auto compare_node = [&table_infos](
const node_t lhs_nest_level,
582 const node_t rhs_nest_level) {
583 return table_infos[lhs_nest_level].info.getNumTuplesUpperBound() <
584 table_infos[rhs_nest_level].info.getNumTuplesUpperBound();
586 const auto compare_edge = [&compare_node](
const TraversalEdge& lhs_edge,
587 const TraversalEdge& rhs_edge) {
589 if (lhs_edge.join_cost == rhs_edge.join_cost) {
590 return compare_node(lhs_edge.nest_level, rhs_edge.nest_level);
592 return lhs_edge.join_cost > rhs_edge.join_cost;
598 left_deep_join_quals,
599 qual_normalization_res);
std::unordered_set< node_t > getRoots() const
static bool colvar_comp(const ColumnVar *l, const ColumnVar *r)
#define IS_EQUIVALENCE(X)
SchedulingDependencyTracking build_dependency_tracking(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< std::map< node_t, cost_t >> &join_cost_graph)
static std::pair< std::vector< InnerOuter >, std::vector< InnerOuterStringOpInfos > > normalizeColumnPairs(const Analyzer::BinOper *condition, const Catalog_Namespace::Catalog &cat, const TemporaryTables *temporary_tables)
bool is_constructed_point(const Analyzer::Expr *expr)
static std::unordered_map< SQLTypes, cost_t > std::tuple< cost_t, cost_t, InnerQualDecision > get_join_qual_cost(const Analyzer::Expr *qual, const Executor *executor)
const std::vector< const Analyzer::ColumnVar * > & getGeoArgCvs() const
std::vector< JoinCondition > JoinQualsPerNestingLevel
T visit(const Analyzer::Expr *expr) const
unsigned g_trivial_loop_join_threshold
std::vector< node_t > get_node_input_permutation(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor)
const std::pair< int, int > getTableIdsOfGeoExpr() const
static std::unordered_map< SQLTypes, cost_t > GEO_TYPE_COSTS
void removeNode(const node_t from)
std::vector< std::map< node_t, cost_t > > build_join_cost_graph(const JoinQualsPerNestingLevel &left_deep_join_quals, const std::vector< InputTableInfo > &table_infos, const Executor *executor, std::vector< std::map< node_t, InnerQualDecision >> &qual_detection_res)
const std::string & getGeoFunctionName() const
static bool is_poly_point_rewrite_target_func(std::string_view target_func_name)
const SQLTypeInfo & get_type_info() const
SchedulingDependencyTracking(const size_t node_count)
const Analyzer::Expr * getArg(const size_t i) const
DEVICE void iota(ARGS &&...args)
void addEdge(const node_t from, const node_t to)
std::vector< std::unordered_set< node_t > > inbound_
std::vector< node_t > traverse_join_cost_graph(const std::vector< std::map< node_t, cost_t >> &join_cost_graph, const std::vector< InputTableInfo > &table_infos, const std::function< bool(const node_t lhs_nest_level, const node_t rhs_nest_level)> &compare_node, const std::function< bool(const TraversalEdge &, const TraversalEdge &)> &compare_edge, const JoinQualsPerNestingLevel &left_deep_join_quals, std::vector< std::map< node_t, InnerQualDecision >> &qual_normalization_res)
AccessManager::Decision Decision
InnerQualDecision inner_qual_decision
const std::shared_ptr< Analyzer::Expr > get_own_left_operand() const
static bool is_point_poly_rewrite_target_func(std::string_view target_func_name)