25 const std::shared_ptr<RelFilter>& filter,
26 std::vector<std::shared_ptr<const RelAlgNode>> inputs,
27 std::vector<std::shared_ptr<const RelJoin>>& original_joins)
28 : condition_(filter ? filter->getAndReleaseCondition() : nullptr)
29 , original_filter_(filter)
30 , original_joins_(original_joins) {
31 std::vector<std::unique_ptr<const RexScalar>> operands;
32 bool is_notnull =
true;
36 for (
size_t nesting_level = 0; nesting_level < original_joins.size(); ++nesting_level) {
37 const auto& original_join = original_joins[nesting_level];
38 const auto condition_true =
39 dynamic_cast<const RexLiteral*
>(original_join->getCondition());
40 if (!condition_true || !condition_true->getVal<
bool>()) {
41 if (dynamic_cast<const RexOperator*>(original_join->getCondition())) {
43 is_notnull &&
dynamic_cast<const RexOperator*
>(original_join->getCondition())
47 switch (original_join->getJoinType()) {
49 if (original_join->getCondition()) {
50 operands.emplace_back(original_join->getAndReleaseCondition());
55 if (original_join->getCondition()) {
57 original_join->getAndReleaseCondition());
66 if (!operands.empty()) {
74 if (operands.size() > 1) {
84 for (
const auto& input : inputs) {
94 const size_t nesting_level)
const {
105 "(RelLeftDeepInnerJoin<" +
std::to_string(reinterpret_cast<uint64_t>(
this)) +
">(";
107 for (
const auto& input :
inputs_) {
108 result +=
" " + input->toString();
115 size_t total_size = 0;
116 for (
const auto& input :
inputs_) {
117 total_size += input->size();
132 if (original_join.get() == node) {
142 std::deque<std::shared_ptr<const RelAlgNode>>& inputs,
143 std::vector<std::shared_ptr<const RelJoin>>& original_joins,
144 const std::shared_ptr<const RelJoin>&
join) {
145 original_joins.push_back(join);
146 CHECK_EQ(
size_t(2), join->inputCount());
147 const auto left_input_join =
148 std::dynamic_pointer_cast<
const RelJoin>(join->getAndOwnInput(0));
149 if (left_input_join) {
150 inputs.push_front(join->getAndOwnInput(1));
153 inputs.push_front(join->getAndOwnInput(1));
154 inputs.push_front(join->getAndOwnInput(0));
158 std::pair<std::shared_ptr<RelLeftDeepInnerJoin>, std::shared_ptr<const RelAlgNode>>
162 return {
nullptr,
nullptr};
164 std::deque<std::shared_ptr<const RelAlgNode>> inputs_deque;
165 const auto left_deep_join_filter =
166 std::dynamic_pointer_cast<
RelFilter>(left_deep_join_root);
168 std::dynamic_pointer_cast<
const RelJoin>(left_deep_join_root->getAndOwnInput(0));
170 std::vector<std::shared_ptr<const RelJoin>> original_joins;
172 std::vector<std::shared_ptr<const RelAlgNode>> inputs(inputs_deque.begin(),
174 return {std::make_shared<RelLeftDeepInnerJoin>(
175 left_deep_join_filter, inputs, original_joins),
182 : left_deep_join_(left_deep_join) {
183 std::vector<size_t> input_sizes;
188 input_size_prefix_sums_.resize(input_sizes.size());
190 input_sizes.begin(), input_sizes.end(), input_size_prefix_sums_.begin());
195 if (left_deep_join_->coversOriginalNode(source_node)) {
197 input_size_prefix_sums_.end(),
199 std::less_equal<size_t>());
200 CHECK(it != input_size_prefix_sums_.end());
201 const auto input_node =
202 left_deep_join_->getInput(std::distance(input_size_prefix_sums_.begin(), it));
203 if (it != input_size_prefix_sums_.begin()) {
204 const auto prev_input_count = *(it - 1);
206 const auto input_index = rex_input->
getIndex() - prev_input_count;
215 std::vector<size_t> input_size_prefix_sums_;
225 const std::shared_ptr<RelAlgNode>& node) {
226 const auto left_deep_join_filter =
dynamic_cast<const RelFilter*
>(node.get());
227 if (left_deep_join_filter) {
228 const auto join =
dynamic_cast<const RelJoin*
>(left_deep_join_filter->getInput(0));
236 if (!node || node->inputCount() != 1) {
239 const auto join =
dynamic_cast<const RelJoin*
>(node->getInput(0));
248 RebindRexInputsFromLeftDeepJoin rebind_rex_inputs_from_left_deep_join(left_deep_join);
249 rebind_rex_inputs_from_left_deep_join.visit(rex);
253 std::list<std::shared_ptr<RelAlgNode>> new_nodes;
254 for (
auto& left_deep_join_candidate : nodes) {
255 std::shared_ptr<RelLeftDeepInnerJoin> left_deep_join;
256 std::shared_ptr<const RelAlgNode> old_root;
258 if (!left_deep_join) {
261 CHECK_GE(left_deep_join->inputCount(), size_t(2));
262 for (
size_t nesting_level = 1; nesting_level <= left_deep_join->inputCount() - 1;
264 const auto outer_condition = left_deep_join->getOuterCondition(nesting_level);
265 if (outer_condition) {
270 left_deep_join.get());
271 for (
auto& node : nodes) {
272 if (node && node->hasInput(old_root.get())) {
273 node->replaceInput(left_deep_join_candidate, left_deep_join);
274 std::shared_ptr<const RelJoin> old_join;
275 if (std::dynamic_pointer_cast<const RelJoin>(left_deep_join_candidate)) {
276 old_join = std::static_pointer_cast<
const RelJoin>(left_deep_join_candidate);
278 CHECK_EQ(
size_t(1), left_deep_join_candidate->inputCount());
279 old_join = std::dynamic_pointer_cast<
const RelJoin>(
280 left_deep_join_candidate->getAndOwnInput(0));
285 std::dynamic_pointer_cast<
const RelJoin>(old_join->getAndOwnInput(0));
290 new_nodes.emplace_back(std::move(left_deep_join));
296 nodes.insert(nodes.begin(), new_nodes.begin(), new_nodes.end());
std::vector< std::unique_ptr< const RexScalar > > outer_conditions_per_level_
size_t size() const override
std::shared_ptr< const RelAlgNode > get_left_deep_join_root(const std::shared_ptr< RelAlgNode > &node)
bool coversOriginalNode(const RelAlgNode *node) const
std::pair< std::shared_ptr< RelLeftDeepInnerJoin >, std::shared_ptr< const RelAlgNode > > create_left_deep_join(const std::shared_ptr< RelAlgNode > &left_deep_join_root)
const RexScalar * getOuterCondition(const size_t nesting_level) const
void addManagedInput(std::shared_ptr< const RelAlgNode > input)
std::shared_ptr< const RelAlgNode > getAndOwnInput(const size_t idx) const
std::string toString() const override
const std::vector< std::shared_ptr< const RelJoin > > original_joins_
std::shared_ptr< RelAlgNode > deepCopy() const override
DEVICE void partial_sum(ARGS &&...args)
const RelAlgNode * getInput(const size_t idx) const
DEVICE auto lower_bound(ARGS &&...args)
virtual size_t size() const =0
std::unique_ptr< const RexScalar > condition_
RelLeftDeepInnerJoin(const std::shared_ptr< RelFilter > &filter, RelAlgInputs inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins)
const RexScalar * getInnerCondition() const
void collect_left_deep_join_inputs(std::deque< std::shared_ptr< const RelAlgNode >> &inputs, std::vector< std::shared_ptr< const RelJoin >> &original_joins, const std::shared_ptr< const RelJoin > &join)
void replaceInput(std::shared_ptr< const RelAlgNode > old_input, std::shared_ptr< const RelAlgNode > input) override
const std::shared_ptr< RelFilter > original_filter_
const size_t inputCount() const
void rebind_inputs_from_left_deep_join(const RexScalar *rex, const RelLeftDeepInnerJoin *left_deep_join)