OmniSciDB  1dac507f6e
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RelAlgOptimizer.h File Reference
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
+ Include dependency graph for RelAlgOptimizer.h:
+ This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

std::unordered_map< const
RelAlgNode
*, std::unordered_set< const
RelAlgNode * > > 
build_du_web (const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void eliminate_identical_copy (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void eliminate_dead_columns (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void fold_filters (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void hoist_filter_cond_to_cross_join (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void simplify_sort (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 
void sink_projected_boolean_expr_to_join (std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
 

Function Documentation

std::unordered_map<const RelAlgNode*, std::unordered_set<const RelAlgNode*> > build_du_web ( const std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 387 of file RelAlgOptimizer.cpp.

References CHECK(), join(), and src.

Referenced by eliminate_dead_columns(), eliminate_identical_copy(), fold_filters(), hoist_filter_cond_to_cross_join(), and sink_projected_boolean_expr_to_join().

// build_du_web: builds a def-use map ("web") over the given nodes.
// Each key is a node; its mapped set holds the nodes that read it via getInput().
// RelScan and RelModify nodes are skipped both as starting points and as inputs,
// so they never receive an entry. Returns the map by value.
388  {
389  std::unordered_map<const RelAlgNode*, std::unordered_set<const RelAlgNode*>> web;
390  std::unordered_set<const RelAlgNode*> visited;
391  std::vector<const RelAlgNode*> work_set;
392  for (auto node : nodes) {
393  if (std::dynamic_pointer_cast<RelScan>(node) ||
394  std::dynamic_pointer_cast<RelModify>(node) || visited.count(node.get())) {
395  continue;
396  }
397  work_set.push_back(node.get());
// Depth-first walk from this node through its inputs, using work_set as a stack.
398  while (!work_set.empty()) {
399  auto walker = work_set.back();
400  work_set.pop_back();
401  if (visited.count(walker)) {
402  continue;
403  }
// NOTE(review): inputs get a web entry (line 425) before they are pushed, so this
// CHECK only holds if every input was already visited when it is popped --
// presumably `nodes` is ordered inputs-first; confirm against callers.
404  CHECK(!web.count(walker));
405  auto it_ok =
406  web.insert(std::make_pair(walker, std::unordered_set<const RelAlgNode*>{}));
407  CHECK(it_ok.second);
408  visited.insert(walker);
// The casts below are used only to assert that walker is one of the expected
// node kinds.
409  const auto join = dynamic_cast<const RelJoin*>(walker);
410  const auto project = dynamic_cast<const RelProject*>(walker);
411  const auto aggregate = dynamic_cast<const RelAggregate*>(walker);
412  const auto filter = dynamic_cast<const RelFilter*>(walker);
413  const auto sort = dynamic_cast<const RelSort*>(walker);
414  const auto left_deep_join = dynamic_cast<const RelLeftDeepInnerJoin*>(walker);
415  const auto logical_values = dynamic_cast<const RelLogicalValues*>(walker);
416  const auto table_func = dynamic_cast<const RelTableFunction*>(walker);
417  CHECK(join || project || aggregate || filter || sort || left_deep_join ||
418  logical_values || table_func);
// Record walker as a user of each non-scan, non-modify input, then queue the input.
419  for (size_t i = 0; i < walker->inputCount(); ++i) {
420  auto src = walker->getInput(i);
421  if (dynamic_cast<const RelScan*>(src) || dynamic_cast<const RelModify*>(src)) {
422  continue;
423  }
424  if (web.empty() || !web.count(src)) {
425  web.insert(std::make_pair(src, std::unordered_set<const RelAlgNode*>{}));
426  }
427  web[src].insert(walker);
428  work_set.push_back(src);
429  }
430  }
431  }
432  return web;
433 }
std::string join(T const &container, std::string const &delim)
int64_t * src
CHECK(cgen_state)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void eliminate_dead_columns ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 1112 of file RelAlgOptimizer.cpp.

References anonymous_namespace{RelAlgOptimizer.cpp}::any_dead_col_in(), build_du_web(), CHECK(), anonymous_namespace{RelAlgOptimizer.cpp}::does_redef_cols(), LOG, anonymous_namespace{RelAlgOptimizer.cpp}::mark_live_columns(), anonymous_namespace{RelAlgOptimizer.cpp}::propagate_input_renumbering(), anonymous_namespace{RelAlgOptimizer.cpp}::sweep_dead_columns(), anonymous_namespace{RelAlgOptimizer.cpp}::try_insert_coalesceable_proj(), and logger::WARNING.

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

// eliminate_dead_columns: mark / sweep / propagate pass over `nodes`.
//  - Returns immediately when `nodes` is empty or its last entry is null.
//  - The last entry (root) must be neither a RelScan nor a RelJoin.
//  - Mark: mark_live_columns() yields the live output columns per node. Nodes for
//    which any_dead_col_in() is false are recorded as intact; if no node has a
//    dead column the function returns without modifying anything further.
//  - Sweep / Propagate: delegated to sweep_dead_columns() and
//    propagate_input_renumbering().
// Listing line 1173 (the propagate_input_renumbering call head) was missing from
// this listing and has been restored.
1112  {
1113  if (nodes.empty()) {
1114  return;
1115  }
1116  auto root = nodes.back().get();
1117  if (!root) {
1118  return;
1119  }
1120  CHECK(!dynamic_cast<const RelScan*>(root) && !dynamic_cast<const RelJoin*>(root));
1121  // Mark
1122  auto old_liveouts = mark_live_columns(nodes);
1123  std::unordered_set<const RelAlgNode*> intact_nodes;
1124  bool has_dead_cols = false;
1125  for (auto live_pair : old_liveouts) {
1126  auto node = live_pair.first;
1127  const auto& outs = live_pair.second;
// NOTE(review): after the warning below the node is added to intact_nodes but
// control still reaches any_dead_col_in(); confirm that a `continue` is not intended.
1128  if (outs.empty()) {
1129  LOG(WARNING) << "RA node with no used column: " << node->toString();
1130  // Ignore empty live_out due to some invalid node
1131  intact_nodes.insert(node);
1132  }
1133  if (any_dead_col_in(node, outs)) {
1134  has_dead_cols = true;
1135  } else {
1136  intact_nodes.insert(node);
1137  }
1138  }
1139  if (!has_dead_cols) {
1140  return;
1141  }
1142  auto web = build_du_web(nodes);
1143  try_insert_coalesceable_proj(nodes, old_liveouts, web);
1144 
// A node that is not yet intact and for which does_redef_cols() is false becomes
// intact when every input is either a RelScan or already intact.
1145  for (auto node : nodes) {
1146  if (intact_nodes.count(node.get()) || does_redef_cols(node.get())) {
1147  continue;
1148  }
1149  bool intact = true;
1150  for (size_t i = 0; i < node->inputCount(); ++i) {
1151  auto source = node->getInput(i);
1152  if (!dynamic_cast<const RelScan*>(source) && !intact_nodes.count(source)) {
1153  intact = false;
1154  break;
1155  }
1156  }
1157  if (intact) {
1158  intact_nodes.insert(node.get());
1159  }
1160  }
1161 
// Snapshot each node's size() before the sweep changes it.
1162  std::unordered_map<const RelAlgNode*, size_t> orig_node_sizes;
1163  for (auto node : nodes) {
1164  orig_node_sizes.insert(std::make_pair(node.get(), node->size()));
1165  }
1166  // Sweep
1167  std::unordered_map<const RelAlgNode*, std::unordered_map<size_t, size_t>>
1168  liveout_renumbering;
1169  std::vector<const RelAlgNode*> ready_nodes;
1170  std::tie(liveout_renumbering, ready_nodes) =
1171  sweep_dead_columns(old_liveouts, nodes, intact_nodes, web, orig_node_sizes);
1172  // Propagate
1173  propagate_input_renumbering(
1174  liveout_renumbering, ready_nodes, old_liveouts, intact_nodes, web, orig_node_sizes);
1175 }
std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * > > build_du_web(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
#define LOG(tag)
Definition: Logger.h:185
bool does_redef_cols(const RelAlgNode *node)
void propagate_input_renumbering(std::unordered_map< const RelAlgNode *, std::unordered_map< size_t, size_t >> &liveout_renumbering, const std::vector< const RelAlgNode * > &ready_nodes, const std::unordered_map< const RelAlgNode *, std::unordered_set< size_t >> &old_liveouts, const std::unordered_set< const RelAlgNode * > &intact_nodes, const std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web, const std::unordered_map< const RelAlgNode *, size_t > &orig_node_sizes)
std::unordered_map< const RelAlgNode *, std::unordered_set< size_t > > mark_live_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
CHECK(cgen_state)
bool any_dead_col_in(const RelAlgNode *node, const std::unordered_set< size_t > &live_outs)
void try_insert_coalesceable_proj(std::vector< std::shared_ptr< RelAlgNode >> &nodes, std::unordered_map< const RelAlgNode *, std::unordered_set< size_t >> &liveouts, std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web)
std::pair< std::unordered_map< const RelAlgNode *, std::unordered_map< size_t, size_t > >, std::vector< const RelAlgNode * > > sweep_dead_columns(const std::unordered_map< const RelAlgNode *, std::unordered_set< size_t >> &live_outs, const std::vector< std::shared_ptr< RelAlgNode >> &nodes, const std::unordered_set< const RelAlgNode * > &intact_nodes, const std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web, const std::unordered_map< const RelAlgNode *, size_t > &orig_node_sizes)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void eliminate_identical_copy ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 438 of file RelAlgOptimizer.cpp.

References build_du_web(), CHECK(), CHECK_EQ, anonymous_namespace{RelAlgOptimizer.cpp}::cleanup_dead_nodes(), anonymous_namespace{RelAlgOptimizer.cpp}::get_visible_projects(), anonymous_namespace{RelAlgOptimizer.cpp}::is_distinct(), anonymous_namespace{RelAlgOptimizer.cpp}::is_identical_copy(), and anonymous_namespace{RelAlgOptimizer.cpp}::redirect_inputs_of().

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

// eliminate_identical_copy: three phases over `nodes`.
//  1. Collect "copy" aggregates: RelAggregate nodes other than the last node, with
//     exactly one group-by key and no aggregate expressions, whose input 0 is a
//     RelProject with the same size() and getFields().
//  2. For each RelJoin whose last input is such an aggregate of size 1, whose
//     project column 0 passes is_distinct(), and whose project input is a RelSort
//     or RelScan: replace that join input with the project's input.
//  3. Build the def-use web, collect simple projects accepted by
//     is_identical_copy(), redirect inputs around them, then cleanup_dead_nodes().
// NOTE(review): nodes.back() is called without an empty() check, unlike
// eliminate_dead_columns and fold_filters -- confirm callers never pass an empty
// vector.
438  {
439  std::unordered_set<std::shared_ptr<const RelAlgNode>> copies;
440  auto sink = nodes.back();
441  for (auto node : nodes) {
442  auto aggregate = std::dynamic_pointer_cast<const RelAggregate>(node);
443  if (!aggregate || aggregate == sink ||
444  !(aggregate->getGroupByCount() == 1 && aggregate->getAggExprsCount() == 0)) {
445  continue;
446  }
447  auto project =
448  std::dynamic_pointer_cast<const RelProject>(aggregate->getAndOwnInput(0));
449  if (project && project->size() == aggregate->size() &&
450  project->getFields() == aggregate->getFields()) {
451  CHECK_EQ(size_t(0), copies.count(aggregate));
452  copies.insert(aggregate);
453  }
454  }
// Phase 2: only the last input of each node is inspected.
455  for (auto node : nodes) {
456  if (!node->inputCount()) {
457  continue;
458  }
459  auto last_source = node->getAndOwnInput(node->inputCount() - 1);
460  if (!copies.count(last_source)) {
461  continue;
462  }
463  auto aggregate = std::dynamic_pointer_cast<const RelAggregate>(last_source);
464  CHECK(aggregate);
465  if (!std::dynamic_pointer_cast<const RelJoin>(node) || aggregate->size() != 1) {
466  continue;
467  }
468  auto project =
469  std::dynamic_pointer_cast<const RelProject>(aggregate->getAndOwnInput(0));
470  CHECK(project);
471  CHECK_EQ(size_t(1), project->size());
472  if (!is_distinct(size_t(0), project.get())) {
473  continue;
474  }
475  auto new_source = project->getAndOwnInput(0);
476  if (std::dynamic_pointer_cast<const RelSort>(new_source) ||
477  std::dynamic_pointer_cast<const RelScan>(new_source)) {
478  node->replaceInput(last_source, new_source);
479  }
480  }
// Swap with an empty temporary so the shared_ptr copies held in `copies` are
// released here rather than at end of scope.
481  decltype(copies)().swap(copies);
482 
483  auto web = build_du_web(nodes);
484 
// Phase 3: a simple project qualifies only if it is not visible from the last
// node or is not renaming, and is_identical_copy() accepts it.
485  std::unordered_set<const RelProject*> projects;
486  std::unordered_set<const RelProject*> permutating_projects;
487  auto visible_projs = get_visible_projects(nodes.back().get());
488  for (auto node : nodes) {
489  auto project = std::dynamic_pointer_cast<RelProject>(node);
490  if (project && project->isSimple() &&
491  (!visible_projs.count(project.get()) || !project->isRenaming()) &&
492  is_identical_copy(project.get(), web, projects, permutating_projects)) {
493  projects.insert(project.get());
494  }
495  }
496 
497  for (auto node : nodes) {
498  redirect_inputs_of(node, projects, permutating_projects, web);
499  }
500 
501  cleanup_dead_nodes(nodes);
502 }
bool is_identical_copy(const RelProject *project, const std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web, const std::unordered_set< const RelProject * > &projects_to_remove, std::unordered_set< const RelProject * > &permutating_projects)
std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * > > build_du_web(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
#define CHECK_EQ(x, y)
Definition: Logger.h:198
void redirect_inputs_of(std::shared_ptr< RelAlgNode > node, const std::unordered_set< const RelProject * > &projects, const std::unordered_set< const RelProject * > &permutating_projects, const std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web)
std::unordered_set< const RelProject * > get_visible_projects(const RelAlgNode *root)
CHECK(cgen_state)
void cleanup_dead_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
bool is_distinct(const size_t input_idx, const RelAlgNode *node)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void fold_filters ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 1365 of file RelAlgOptimizer.cpp.

References build_du_web(), CHECK(), CHECK_EQ, anonymous_namespace{RelAlgOptimizer.cpp}::cleanup_dead_nodes(), logger::INFO, kAND, kBOOLEAN, LOG, and anonymous_namespace{RelAlgOptimizer.cpp}::replace_all_usages().

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

1365  {
1366  std::unordered_map<const RelAlgNode*, std::shared_ptr<RelAlgNode>> deconst_mapping;
1367  for (auto node : nodes) {
1368  deconst_mapping.insert(std::make_pair(node.get(), node));
1369  }
1370 
1371  auto web = build_du_web(nodes);
1372  for (auto node_it = nodes.rbegin(); node_it != nodes.rend(); ++node_it) {
1373  auto& node = *node_it;
1374  if (auto filter = std::dynamic_pointer_cast<RelFilter>(node)) {
1375  CHECK_EQ(filter->inputCount(), size_t(1));
1376  auto src_filter = dynamic_cast<const RelFilter*>(filter->getInput(0));
1377  if (!src_filter) {
1378  continue;
1379  }
1380  auto siblings_it = web.find(src_filter);
1381  if (siblings_it == web.end() || siblings_it->second.size() != size_t(1)) {
1382  continue;
1383  }
1384  auto src_it = deconst_mapping.find(src_filter);
1385  CHECK(src_it != deconst_mapping.end());
1386  auto folded_filter = std::dynamic_pointer_cast<RelFilter>(src_it->second);
1387  CHECK(folded_filter);
1388  // TODO(miyu) : drop filter w/ only expression valued constant TRUE?
1389  if (auto rex_operator = dynamic_cast<const RexOperator*>(filter->getCondition())) {
1390  LOG(INFO) << "ID=" << filter->getId() << " " << filter->toString()
1391  << " folded into "
1392  << "ID=" << folded_filter->getId() << " " << folded_filter->toString()
1393  << std::endl;
1394  std::vector<std::unique_ptr<const RexScalar>> operands;
1395  operands.emplace_back(folded_filter->getAndReleaseCondition());
1396  auto old_condition = dynamic_cast<const RexOperator*>(operands.back().get());
1397  CHECK(old_condition && old_condition->getType().get_type() == kBOOLEAN);
1398  RexInputRedirector redirector(folded_filter.get(), folded_filter->getInput(0));
1399  operands.push_back(redirector.visit(rex_operator));
1400  auto other_condition = dynamic_cast<const RexOperator*>(operands.back().get());
1401  CHECK(other_condition && other_condition->getType().get_type() == kBOOLEAN);
1402  const bool notnull = old_condition->getType().get_notnull() &&
1403  other_condition->getType().get_notnull();
1404  auto new_condition = std::unique_ptr<const RexScalar>(
1405  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, notnull)));
1406  folded_filter->setCondition(new_condition);
1407  replace_all_usages(filter, folded_filter, deconst_mapping, web);
1408  deconst_mapping.erase(filter.get());
1409  web.erase(filter.get());
1410  web[filter->getInput(0)].erase(filter.get());
1411  node.reset();
1412  }
1413  }
1414  }
1415 
1416  if (!nodes.empty()) {
1417  auto sink = nodes.back();
1418  for (auto node_it = std::next(nodes.rend()); !sink && node_it != nodes.rbegin();
1419  ++node_it) {
1420  sink = *node_it;
1421  }
1422  CHECK(sink);
1423  cleanup_dead_nodes(nodes);
1424  }
1425 }
std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * > > build_du_web(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
#define CHECK_EQ(x, y)
Definition: Logger.h:198
#define LOG(tag)
Definition: Logger.h:185
CHECK(cgen_state)
void cleanup_dead_nodes(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
Definition: sqldefs.h:37
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:852
void replace_all_usages(std::shared_ptr< const RelAlgNode > old_def_node, std::shared_ptr< const RelAlgNode > new_def_node, std::unordered_map< const RelAlgNode *, std::shared_ptr< RelAlgNode >> &deconst_mapping, std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * >> &du_web)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void hoist_filter_cond_to_cross_join ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 1525 of file RelAlgOptimizer.cpp.

References build_du_web(), CHECK(), find_hoistable_conditions(), RelFilter::getCondition(), RelAlgNode::getInput(), INNER, join(), kAND, kBOOLEAN, RelFilter::setCondition(), and RexVisitorBase< T >::visit().

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

1526  {
1527  std::unordered_set<const RelAlgNode*> visited;
1528  auto web = build_du_web(nodes);
1529  for (auto node : nodes) {
1530  if (visited.count(node.get())) {
1531  continue;
1532  }
1533  visited.insert(node.get());
1534  auto join = dynamic_cast<RelJoin*>(node.get());
1535  if (join && join->getJoinType() == JoinType::INNER) {
1536  // Only allow cross join for now.
1537  if (auto literal = dynamic_cast<const RexLiteral*>(join->getCondition())) {
1538  // Assume Calcite always generates an inner join on constant boolean true for
1539  // cross join.
1540  CHECK(literal->getType() == kBOOLEAN && literal->getVal<bool>());
1541  size_t first_col_idx = 0;
1542  const RelFilter* filter = nullptr;
1543  std::vector<const RelJoin*> join_seq{join};
1544  for (const RelJoin* curr_join = join; !filter;) {
1545  auto usrs_it = web.find(curr_join);
1546  CHECK(usrs_it != web.end());
1547  if (usrs_it->second.size() != size_t(1)) {
1548  break;
1549  }
1550  auto only_usr = *usrs_it->second.begin();
1551  if (auto usr_join = dynamic_cast<const RelJoin*>(only_usr)) {
1552  if (join == usr_join->getInput(1)) {
1553  const auto src1_offset = usr_join->getInput(0)->size();
1554  first_col_idx += src1_offset;
1555  }
1556  join_seq.push_back(usr_join);
1557  curr_join = usr_join;
1558  continue;
1559  }
1560 
1561  filter = dynamic_cast<const RelFilter*>(only_usr);
1562  break;
1563  }
1564  if (!filter) {
1565  visited.insert(join_seq.begin(), join_seq.end());
1566  continue;
1567  }
1568  const auto src_join = dynamic_cast<const RelJoin*>(filter->getInput(0));
1569  CHECK(src_join);
1570  auto modified_filter = const_cast<RelFilter*>(filter);
1571 
1572  if (src_join == join) {
1573  std::unique_ptr<const RexScalar> filter_condition(
1574  modified_filter->getAndReleaseCondition());
1575  std::unique_ptr<const RexScalar> true_condition =
1576  boost::make_unique<RexLiteral>(true,
1577  kBOOLEAN,
1578  kBOOLEAN,
1579  unsigned(-2147483648),
1580  1,
1581  unsigned(-2147483648),
1582  1);
1583  modified_filter->setCondition(true_condition);
1584  join->setCondition(filter_condition);
1585  continue;
1586  }
1587  const auto src1_base = src_join->getInput(0)->size();
1588  auto source =
1589  first_col_idx < src1_base ? src_join->getInput(0) : src_join->getInput(1);
1590  first_col_idx =
1591  first_col_idx < src1_base ? first_col_idx : first_col_idx - src1_base;
1592  auto join_conditions =
1594  source,
1595  first_col_idx,
1596  first_col_idx + join->size() - 1);
1597  if (join_conditions.empty()) {
1598  continue;
1599  }
1600 
1601  JoinTargetRebaser rebaser(join, first_col_idx);
1602  if (join_conditions.size() == 1) {
1603  auto new_join_condition = rebaser.visit(*join_conditions.begin());
1604  join->setCondition(new_join_condition);
1605  } else {
1606  std::vector<std::unique_ptr<const RexScalar>> operands;
1607  bool notnull = true;
1608  for (size_t i = 0; i < join_conditions.size(); ++i) {
1609  operands.emplace_back(rebaser.visit(join_conditions[i]));
1610  auto old_subcond = dynamic_cast<const RexOperator*>(join_conditions[i]);
1611  CHECK(old_subcond && old_subcond->getType().get_type() == kBOOLEAN);
1612  notnull = notnull && old_subcond->getType().get_notnull();
1613  }
1614  auto new_join_condition = std::unique_ptr<const RexScalar>(
1615  new RexOperator(kAND, operands, SQLTypeInfo(kBOOLEAN, notnull)));
1616  join->setCondition(new_join_condition);
1617  }
1618 
1619  SubConditionRemover remover(join_conditions);
1620  auto new_filter_condition = remover.visit(filter->getCondition());
1621  modified_filter->setCondition(new_filter_condition);
1622  }
1623  }
1624  }
1625 }
std::vector< const RexScalar * > find_hoistable_conditions(const RexScalar *condition, const RelAlgNode *source, const size_t first_col_idx, const size_t last_col_idx)
std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * > > build_du_web(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
void setCondition(std::unique_ptr< const RexScalar > &condition)
const RexScalar * getCondition() const
std::string join(T const &container, std::string const &delim)
CHECK(cgen_state)
const RelAlgNode * getInput(const size_t idx) const
Definition: sqldefs.h:37
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:852

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void simplify_sort ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 1630 of file RelAlgOptimizer.cpp.

References RelAlgNode::replaceInput().

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

// simplify_sort: collapses a RelSort / identity RelProject / RelSort triple.
// For each window of three consecutive entries nodes[i], nodes[i+1], nodes[i+2]
// that are a RelSort, an identity RelProject and a RelSort comparing equal to the
// first, the second sort is rewired to read the first sort's input, and slots i and
// i+1 are reset. Null slots are then dropped from `nodes`.
// Does nothing when fewer than three nodes are present.
1630  {
1631  if (nodes.size() < 3) {
1632  return;
1633  }
// The size check above keeps `nodes.size() - 3` from wrapping around.
1634  for (size_t i = 0; i <= nodes.size() - 3;) {
1635  auto first_sort = std::dynamic_pointer_cast<RelSort>(nodes[i]);
1636  const auto project = std::dynamic_pointer_cast<const RelProject>(nodes[i + 1]);
1637  auto second_sort = std::dynamic_pointer_cast<RelSort>(nodes[i + 2]);
1638  if (first_sort && second_sort && project && project->isIdentity() &&
1639  *first_sort == *second_sort) {
1640  second_sort->replaceInput(second_sort->getAndOwnInput(0),
1641  first_sort->getAndOwnInput(0));
1642  nodes[i].reset();
1643  nodes[i + 1].reset();
// Skip past the whole matched triple.
1644  i += 3;
1645  } else {
1646  ++i;
1647  }
1648  }
1649 
// Compact: keep only the slots that were not reset above.
1650  std::vector<std::shared_ptr<RelAlgNode>> new_nodes;
1651  for (auto node : nodes) {
1652  if (!node) {
1653  continue;
1654  }
1655  new_nodes.push_back(node);
1656  }
1657  nodes.swap(new_nodes);
1658 }
virtual void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void sink_projected_boolean_expr_to_join ( std::vector< std::shared_ptr< RelAlgNode >> &  nodes)
noexcept

Definition at line 1217 of file RelAlgOptimizer.cpp.

References build_du_web(), CHECK(), CHECK_EQ, join(), kBOOLEAN, and anonymous_namespace{RelAlgOptimizer.cpp}::mark_live_columns().

Referenced by anonymous_namespace{RelAlgAbstractInterpreter.cpp}::RelAlgAbstractInterpreter::run().

// sink_projected_boolean_expr_to_join: moves boolean expressions computed by a
// project into the join that consumes it.
// A project is considered only when it is not simple, its input 0 is a RelScan,
// and its single user in the def-use web is a RelJoin. Every projected column must
// be either a boolean RexOperator or a plain RexInput, and none of the boolean
// columns may be live out of the join. The project is then rewritten to emit plain
// inputs only, and the join condition is rewritten through SubConditionReplacer.
1218  {
1219  auto web = build_du_web(nodes);
1220  auto liveouts = mark_live_columns(nodes);
1221  for (auto node : nodes) {
1222  auto project = std::dynamic_pointer_cast<RelProject>(node);
1223  // TODO(miyu): relax RelScan limitation
1224  if (!project || project->isSimple() ||
1225  !dynamic_cast<const RelScan*>(project->getInput(0))) {
1226  continue;
1227  }
1228  auto usrs_it = web.find(project.get());
1229  CHECK(usrs_it != web.end());
1230  auto& usrs = usrs_it->second;
1231  if (usrs.size() != 1) {
1232  continue;
1233  }
1234  auto join = dynamic_cast<RelJoin*>(const_cast<RelAlgNode*>(*usrs.begin()));
1235  if (!join) {
1236  continue;
1237  }
1238  auto outs_it = liveouts.find(join);
1239  CHECK(outs_it != liveouts.end());
// in_to_out_index: scan input index -> project output index.
// boolean_expr_indicies: project output indices holding boolean operators.
1240  std::unordered_map<size_t, size_t> in_to_out_index;
1241  std::unordered_set<size_t> boolean_expr_indicies;
1242  bool discarded = false;
1243  for (size_t i = 0; i < project->size(); ++i) {
1244  auto oper = dynamic_cast<const RexOperator*>(project->getProjectAt(i));
1245  if (oper && oper->getType().get_type() == kBOOLEAN) {
1246  boolean_expr_indicies.insert(i);
1247  } else {
1248  // TODO(miyu): relax?
1249  if (auto input = dynamic_cast<const RexInput*>(project->getProjectAt(i))) {
1250  in_to_out_index.insert(std::make_pair(input->getIndex(), i));
1251  } else {
1252  discarded = true;
1253  }
1254  }
1255  }
1256  if (discarded || boolean_expr_indicies.empty()) {
1257  continue;
1258  }
// Offset of the project's columns within the join output: 0 when the project is
// the join's input 0, otherwise the size of input 0.
1259  const size_t index_base =
1260  join->getInput(0) == project.get() ? 0 : join->getInput(0)->size();
1261  for (auto i : boolean_expr_indicies) {
1262  auto join_idx = index_base + i;
1263  if (outs_it->second.count(join_idx)) {
1264  discarded = true;
1265  break;
1266  }
1267  }
1268  if (discarded) {
1269  continue;
1270  }
1271  RexInputCollector collector(project.get());
1272  std::vector<size_t> unloaded_input_indices;
1273  std::unordered_map<size_t, std::unique_ptr<const RexScalar>> in_idx_to_new_subcond;
1274  // Given all are dead right after join, safe to sink
1275  for (auto i : boolean_expr_indicies) {
1276  auto rex_ins = collector.visit(project->getProjectAt(i));
1277  for (auto& in : rex_ins) {
1278  CHECK_EQ(in.getSourceNode(), project->getInput(0));
// Inputs not yet projected get a new output slot appended after the existing ones.
1279  if (!in_to_out_index.count(in.getIndex())) {
1280  auto curr_out_index = project->size() + unloaded_input_indices.size();
1281  in_to_out_index.insert(std::make_pair(in.getIndex(), curr_out_index));
1282  unloaded_input_indices.push_back(in.getIndex());
1283  }
// NOTE(review): the sinker runs inside the per-input loop and insert() keeps only
// the first result for index i, i.e. the one produced before later inputs of the
// same expression were added to in_to_out_index -- confirm RexInputSinker
// tolerates that.
1284  RexInputSinker sinker(in_to_out_index, project.get());
1285  in_idx_to_new_subcond.insert(
1286  std::make_pair(i, sinker.visit(project->getProjectAt(i))));
1287  }
1288  }
1289  if (in_idx_to_new_subcond.empty()) {
1290  continue;
1291  }
// Rebuild the projection list: boolean slots become a reference to scan column 0,
// plain inputs are copied, and the newly needed inputs are appended.
1292  std::vector<std::unique_ptr<const RexScalar>> new_projections;
1293  for (size_t i = 0; i < project->size(); ++i) {
1294  if (boolean_expr_indicies.count(i)) {
1295  new_projections.push_back(boost::make_unique<RexInput>(project->getInput(0), 0));
1296  } else {
1297  auto rex_input = dynamic_cast<const RexInput*>(project->getProjectAt(i));
1298  CHECK(rex_input != nullptr);
1299  new_projections.push_back(rex_input->deepCopy());
1300  }
1301  }
1302  for (auto i : unloaded_input_indices) {
1303  new_projections.push_back(boost::make_unique<RexInput>(project->getInput(0), i));
1304  }
1305  project->setExpressions(new_projections);
1306 
1307  SubConditionReplacer replacer(in_idx_to_new_subcond);
1308  auto new_condition = replacer.visit(join->getCondition());
1309  join->setCondition(new_condition);
1310  }
1311 }
std::unordered_map< const RelAlgNode *, std::unordered_set< const RelAlgNode * > > build_du_web(const std::vector< std::shared_ptr< RelAlgNode >> &nodes) noexcept
#define CHECK_EQ(x, y)
Definition: Logger.h:198
std::string join(T const &container, std::string const &delim)
std::unordered_map< const RelAlgNode *, std::unordered_set< size_t > > mark_live_columns(std::vector< std::shared_ptr< RelAlgNode >> &nodes)
CHECK(cgen_state)

+ Here is the call graph for this function:

+ Here is the caller graph for this function: