OmniSciDB  cde582ebc3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
WindowFunctionIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "WindowContext.h"
20 
// Emits the IR that produces the value of the window function registered at
// `target_index` for the current row, dispatching on the window function kind.
//
// @param target_index  index used to look up the window function context
// @param co            compilation options forwarded to nested codegen calls
// @return the llvm::Value holding the window function result; nullptr only as
//         the fall-through after the switch.
//
// NOTE(review): the `case` labels of the switch and the call that fetches the
// window function context are not visible in this listing, so the per-branch
// notes below describe only the statements that are visible.
21 llvm::Value* Executor::codegenWindowFunction(const size_t target_index,
22  const CompilationOptions& co) {
23  AUTOMATIC_IR_METADATA(cgen_state_.get());
24  CodeGenerator code_generator(this);
25  const auto window_func_context =
27  target_index);
28  const auto window_func = window_func_context->getWindowFunction();
29  switch (window_func->getKind()) {
  // Branch 1: result comes from codegenWindowPosition() for the current row
  // position (posArg(nullptr)).
34  // they are always evaluated on the entire partition
35  return code_generator.codegenWindowPosition(window_func_context,
36  code_generator.posArg(nullptr));
37  }
  // Branch 2: calls the runtime function "percent_window_func" with the address
  // of the context's output buffer (embedded as an int64 constant) and the
  // current row position.
40  // they are always evaluated on the entire partition
41  return cgen_state_->emitCall("percent_window_func",
42  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
43  window_func_context->output())),
44  code_generator.posArg(nullptr)});
45  }
  // Branch 3: generates code for the first argument and returns that single
  // value directly; no aggregate call is emitted here.
50  // they are always evaluated on the current frame
52  const auto& args = window_func->getArgs();
53  CHECK(!args.empty());
54  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
55  CHECK_EQ(arg_lvs.size(), size_t(1));
56  return arg_lvs.front();
57  }
  // Branch 4: aggregate-style window functions are delegated to
  // codegenWindowFunctionAggregate().
63  // they are always evaluated on the current frame
64  return codegenWindowFunctionAggregate(co);
65  }
66  default: {
67  LOG(FATAL) << "Invalid window function kind";
68  }
69  }
70  return nullptr;
71 }
72 
// File-local helpers used by the window function codegen below.
73 namespace {
74 
// Builds the name of the runtime aggregate function for a window function:
// a base name chosen from `kind` ("agg_min", "agg_max", "agg_sum" or
// "agg_count") plus a "_float" / "_double" suffix when `window_func_ti` is
// kFLOAT / kDOUBLE; other types get no suffix.
// NOTE(review): the first signature line and the `case` labels of the first
// switch are missing from this listing; which kinds map to which base name is
// therefore not shown here (two labels appear to precede "agg_sum").
76  const SQLTypeInfo& window_func_ti) {
77  std::string agg_name;
78  switch (kind) {
80  agg_name = "agg_min";
81  break;
82  }
84  agg_name = "agg_max";
85  break;
86  }
89  agg_name = "agg_sum";
90  break;
91  }
93  agg_name = "agg_count";
94  break;
95  }
96  default: {
97  LOG(FATAL) << "Invalid window function kind";
98  }
99  }
100  switch (window_func_ti.get_type()) {
101  case kFLOAT: {
102  agg_name += "_float";
103  break;
104  }
105  case kDOUBLE: {
106  agg_name += "_double";
107  break;
108  }
109  default: {
110  break;
111  }
112  }
113  return agg_name;
114 }
115 
// Returns the type info the aggregate state should be handled with:
// the type of the first argument for AVG and for COUNT with at least one
// argument, otherwise the window function's own result type.
// NOTE(review): the signature line is missing from this listing. For AVG the
// code calls args.front() without checking that `args` is non-empty.
117  const auto& args = window_func->getArgs();
118  return ((window_func->getKind() == SqlWindowFunctionKind::COUNT && !args.empty()) ||
119  window_func->getKind() == SqlWindowFunctionKind::AVG)
120  ? args.front()->get_type_info()
121  : window_func->get_type_info();
122 }
123 
124 } // namespace
125 
// Returns an LLVM pointer value to the aggregate state of the active window
// function. The host address from window_func_context->aggregateState() is
// embedded as an int64 constant and converted with IntToPtr to i32* when the
// adjusted argument type is kFLOAT, and to i64* otherwise.
// NOTE(review): the signature line and the expression that fetches the window
// function context are missing from this listing.
127  AUTOMATIC_IR_METADATA(cgen_state_.get());
128  const auto window_func_context =
130  const auto window_func = window_func_context->getWindowFunction();
131  const auto arg_ti = get_adjusted_window_type_info(window_func);
132  llvm::Type* aggregate_state_type =
133  arg_ti.get_type() == kFLOAT
134  ? llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0)
135  : llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
136  const auto aggregate_state_i64 = cgen_state_->llInt(
137  reinterpret_cast<const int64_t>(window_func_context->aggregateState()));
138  return cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_i64,
139  aggregate_state_type);
140 }
141 
// Emits the IR for an aggregate-style window function:
//   1. codegenWindowResetStateControlFlow() opens the "reset state" branch and
//      leaves the insert point inside it;
//   2. inside that branch the aggregate state is (re)initialized, and for AVG
//      the separate count state is set to 0 via "agg_id";
//   3. the branch jumps to the "reset_state.false" block, where the actual
//      aggregate calls are emitted by codegenWindowFunctionAggregateCalls().
// Returns whatever codegenWindowFunctionAggregateCalls() returns.
// NOTE(review): the signature line, the context lookup and one line between
// SetInsertPoint and the final return are missing from this listing.
143  AUTOMATIC_IR_METADATA(cgen_state_.get());
144  const auto reset_state_false_bb = codegenWindowResetStateControlFlow();
145  auto aggregate_state = aggregateWindowStatePtr();
146  llvm::Value* aggregate_state_count = nullptr;
147  const auto window_func_context =
149  const auto window_func = window_func_context->getWindowFunction();
  // AVG keeps a second state (the count); build an i64* to it from the host
  // address returned by aggregateStateCount().
150  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
151  const auto aggregate_state_count_i64 = cgen_state_->llInt(
152  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
153  const auto pi64_type =
154  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
155  aggregate_state_count =
156  cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_count_i64, pi64_type);
157  }
158  codegenWindowFunctionStateInit(aggregate_state);
159  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
160  const auto count_zero = cgen_state_->llInt(int64_t(0));
161  cgen_state_->emitCall("agg_id", {aggregate_state_count, count_zero});
162  }
  // Close the reset branch and continue code generation in the common block.
163  cgen_state_->ir_builder_.CreateBr(reset_state_false_bb);
164  cgen_state_->ir_builder_.SetInsertPoint(reset_state_false_bb);
166  return codegenWindowFunctionAggregateCalls(aggregate_state, co);
167 }
168 
// Emits a conditional branch that decides whether the aggregate state has to be
// reset for the current row. The condition is the result of the runtime call
// "bit_is_set" on the bitset returned by window_func_context->partitionStart(),
// probed at the current row position.
//
// Two basic blocks are created, "reset_state.true" and "reset_state.false".
// On return the IR builder's insert point is inside "reset_state.true"; the
// returned block is "reset_state.false", which the caller must branch to and
// continue in.
// NOTE(review): the signature line and the context lookup are missing from this
// listing.
170  AUTOMATIC_IR_METADATA(cgen_state_.get());
171  const auto window_func_context =
173  const auto bitset = cgen_state_->llInt(
174  reinterpret_cast<const int64_t>(window_func_context->partitionStart()));
  // Arguments for "bit_is_set": valid position range [0, elementCount() - 1] and
  // the null sentinels for the int64 input and int8 result.
175  const auto min_val = cgen_state_->llInt(int64_t(0));
176  const auto max_val = cgen_state_->llInt(window_func_context->elementCount() - 1);
177  const auto null_val = cgen_state_->llInt(inline_int_null_value<int64_t>());
178  const auto null_bool_val = cgen_state_->llInt<int8_t>(inline_int_null_value<int8_t>());
179  CodeGenerator code_generator(this);
180  const auto reset_state =
181  code_generator.toBool(cgen_state_->emitCall("bit_is_set",
182  {bitset,
183  code_generator.posArg(nullptr),
184  min_val,
185  max_val,
186  null_val,
187  null_bool_val}));
188  const auto reset_state_true_bb = llvm::BasicBlock::Create(
189  cgen_state_->context_, "reset_state.true", cgen_state_->current_func_);
190  const auto reset_state_false_bb = llvm::BasicBlock::Create(
191  cgen_state_->context_, "reset_state.false", cgen_state_->current_func_);
192  cgen_state_->ir_builder_.CreateCondBr(
193  reset_state, reset_state_true_bb, reset_state_false_bb);
  // Subsequent IR emitted by the caller lands in the "true" (reset) block.
194  cgen_state_->ir_builder_.SetInsertPoint(reset_state_true_bb);
195  return reset_state_false_bb;
196 }
197 
// Emits the call that stores the initial value into the aggregate state.
//
// @param aggregate_state  pointer value to the aggregate state; for kFLOAT it
//                         is bit-cast to i32* before the store call.
//
// The initial value is either a typed zero (float, double or int64) when the
// kind comparison below holds, or the inline null value of the adjusted window
// type otherwise. The store is done through "agg_id_double", "agg_id_float" or
// "agg_id", selected by the adjusted type.
// NOTE(review): the context lookup and the right-hand side of the kind
// comparison are missing from this listing, so the kind that gets the zero
// initial value is not visible here.
198 void Executor::codegenWindowFunctionStateInit(llvm::Value* aggregate_state) {
199  AUTOMATIC_IR_METADATA(cgen_state_.get());
200  const auto window_func_context =
202  const auto window_func = window_func_context->getWindowFunction();
203  const auto window_func_ti = get_adjusted_window_type_info(window_func);
  // Null sentinel for the adjusted type; integer nulls are widened to 64 bits.
204  const auto window_func_null_val =
205  window_func_ti.is_fp()
206  ? cgen_state_->inlineFpNull(window_func_ti)
207  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
208  llvm::Value* window_func_init_val;
209  if (window_func_context->getWindowFunction()->getKind() ==
211  switch (window_func_ti.get_type()) {
212  case kFLOAT: {
213  window_func_init_val = cgen_state_->llFp(float(0));
214  break;
215  }
216  case kDOUBLE: {
217  window_func_init_val = cgen_state_->llFp(double(0));
218  break;
219  }
220  default: {
221  window_func_init_val = cgen_state_->llInt(int64_t(0));
222  break;
223  }
224  }
225  } else {
226  window_func_init_val = window_func_null_val;
227  }
228  const auto pi32_type =
229  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
  // Pick the store helper that matches the state's value type.
230  switch (window_func_ti.get_type()) {
231  case kDOUBLE: {
232  cgen_state_->emitCall("agg_id_double", {aggregate_state, window_func_init_val});
233  break;
234  }
235  case kFLOAT: {
236  aggregate_state =
237  cgen_state_->ir_builder_.CreateBitCast(aggregate_state, pi32_type);
238  cgen_state_->emitCall("agg_id_float", {aggregate_state, window_func_init_val});
239  break;
240  }
241  default: {
242  cgen_state_->emitCall("agg_id", {aggregate_state, window_func_init_val});
243  break;
244  }
245  }
246 }
247 
248 llvm::Value* Executor::codegenWindowFunctionAggregateCalls(llvm::Value* aggregate_state,
249  const CompilationOptions& co) {
250  AUTOMATIC_IR_METADATA(cgen_state_.get());
251  const auto window_func_context =
253  const auto window_func = window_func_context->getWindowFunction();
254  const auto window_func_ti = get_adjusted_window_type_info(window_func);
255  const auto window_func_null_val =
256  window_func_ti.is_fp()
257  ? cgen_state_->inlineFpNull(window_func_ti)
258  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
259  const auto& args = window_func->getArgs();
260  llvm::Value* crt_val;
261  CodeGenerator code_generator(this);
262  if (args.empty()) {
263  CHECK(window_func->getKind() == SqlWindowFunctionKind::COUNT);
264  crt_val = cgen_state_->llInt(int64_t(1));
265  } else {
266  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
267  CHECK_EQ(arg_lvs.size(), size_t(1));
268  if (window_func->getKind() == SqlWindowFunctionKind::SUM && !window_func_ti.is_fp()) {
269  crt_val = code_generator.codegenCastBetweenIntTypes(
270  arg_lvs.front(), args.front()->get_type_info(), window_func_ti, false);
271  } else {
272  crt_val = window_func_ti.get_type() == kFLOAT
273  ? arg_lvs.front()
274  : cgen_state_->castToTypeIn(arg_lvs.front(), 64);
275  }
276  }
277  const auto agg_name = get_window_agg_name(window_func->getKind(), window_func_ti);
278  if (window_func_context->needsToBuildAggregateTree()) {
279  // compute an aggregated value for each row of the window frame by using segment tree
280  // when constructing a window context, we build a necessary segment tree for it
281  // and use the tree array (so called `aggregate tree`) to query the aggregated value
282  // of the specific window frame
283  // we fall back to the non-framing window func evaluation logic if an input
284  // of the window function can be an empty one
285  const auto pi64_type =
286  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
287  const auto pi32_type =
288  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
289  const auto ppi64_type = llvm::PointerType::get(pi64_type, 0);
290  // this lambda function is only used for window framing codegen
291  auto get_col_type_name_for_framing = [](const SQLTypes type) {
292  switch (type) {
293  case kTINYINT:
294  return "int8_t";
295  case kSMALLINT:
296  return "int16_t";
297  case kINT:
298  return "int32_t";
299  case kBIGINT:
300  return "int64_t";
301  case kFLOAT:
302  return "float";
303  case kDOUBLE:
304  case kNUMERIC:
305  case kDECIMAL:
306  return "double";
307  default: {
308  UNREACHABLE();
309  return "UNREACHABLE";
310  }
311  }
312  };
313  // row_id of the current row in partition, which may be different from row_id in a
314  // table, i.e., pos_arg
315  const auto current_row_pos = code_generator.posArg(nullptr);
316 
317  // # elems per partition
318  const auto partition_count_buf =
319  cgen_state_->llInt(reinterpret_cast<int64_t>(window_func_context->counts()));
320  const auto partition_count_buf_ptr =
321  cgen_state_->ir_builder_.CreateIntToPtr(partition_count_buf, pi32_type);
322 
323  // given current row's pos, calculate the partition index that it belongs to
324  const auto partition_count_lv =
325  cgen_state_->llInt(window_func_context->partitionCount());
326  const auto partition_num_count_buf = cgen_state_->llInt(
327  reinterpret_cast<int64_t>(window_func_context->partitionNumCountBuf()));
328  const auto partition_num_count_ptr =
329  cgen_state_->ir_builder_.CreateIntToPtr(partition_num_count_buf, pi64_type);
330  const auto partition_index_lv = cgen_state_->emitCall(
331  "compute_int64_t_lower_bound",
332  {partition_count_lv, current_row_pos, partition_num_count_ptr});
333 
334  // # elems of the given partition
335  const auto num_elem_current_partition_ptr =
336  cgen_state_->ir_builder_.CreateGEP(get_int_type(32, cgen_state_->context_),
337  partition_count_buf_ptr,
338  partition_index_lv);
339  const auto num_elem_current_partition_lv = cgen_state_->castToTypeIn(
340  cgen_state_->ir_builder_.CreateLoad(
341  num_elem_current_partition_ptr->getType()->getPointerElementType(),
342  num_elem_current_partition_ptr),
343  64);
344 
345  // partial sum of # elems of partitions
346  const auto partition_start_offset_buf = cgen_state_->llInt(
347  reinterpret_cast<int64_t>(window_func_context->partitionStartOffset()));
348  const auto partition_start_offset_ptr =
349  cgen_state_->ir_builder_.CreateIntToPtr(partition_start_offset_buf, pi64_type);
350 
351  // get start offset of the current partition
352  const auto current_partition_start_offset_ptr =
353  cgen_state_->ir_builder_.CreateGEP(get_int_type(64, cgen_state_->context_),
354  partition_start_offset_ptr,
355  partition_index_lv);
356  const auto current_partition_start_offset_lv = cgen_state_->ir_builder_.CreateLoad(
357  current_partition_start_offset_ptr->getType()->getPointerElementType(),
358  current_partition_start_offset_ptr);
359 
360  // a depth of segment tree
361  const auto tree_depth_buf = cgen_state_->llInt(
362  reinterpret_cast<int64_t>(window_func_context->getAggregateTreeDepth()));
363  const auto tree_depth_buf_ptr =
364  cgen_state_->ir_builder_.CreateIntToPtr(tree_depth_buf, pi64_type);
365  const auto current_partition_tree_depth_buf_ptr = cgen_state_->ir_builder_.CreateGEP(
366  get_int_type(64, cgen_state_->context_), tree_depth_buf_ptr, partition_index_lv);
367  const auto current_partition_tree_depth_lv = cgen_state_->ir_builder_.CreateLoad(
368  current_partition_tree_depth_buf_ptr->getType()->getPointerElementType(),
369  current_partition_tree_depth_buf_ptr);
370 
371  // a fanout of the current partition's segment tree
372  const auto aggregation_tree_fanout_lv = cgen_state_->llInt(
373  static_cast<int64_t>(window_func_context->getAggregateTreeFanout()));
374 
375  // agg_type
376  const auto agg_type_lv =
377  cgen_state_->llInt(static_cast<int32_t>(window_func->getKind()));
378 
379  // declare various variables to codegen
380  const auto frame_start_bound = window_func->getFrameStartBound();
381  const auto frame_end_bound = window_func->getFrameEndBound();
382  llvm::Value* order_key_buf_ptr{nullptr};
383  llvm::Value* target_partition_rowid_ptr{nullptr};
384  llvm::Value* target_partition_sorted_rowid_ptr{nullptr};
385  llvm::Value* current_col_value_lv{nullptr};
386  llvm::Value* order_key_col_null_val_lv{nullptr};
387  llvm::Value* null_start_pos_lv{nullptr};
388  llvm::Value* null_end_pos_lv{nullptr};
389  std::vector<llvm::Value*> frame_start_bound_expr_lvs;
390  std::vector<llvm::Value*> frame_end_bound_expr_lvs;
391  llvm::Value* frame_start_bound_expr_lv = nullptr;
392  llvm::Value* frame_end_bound_expr_lv = nullptr;
393  llvm::Value* frame_start_bound_lv = nullptr;
394  llvm::Value* frame_end_bound_lv = nullptr;
395 
396  // codegen frame bound expr if necessary
397  auto needs_bound_expr_codegen = [](const Analyzer::WindowFrame* window_frame) {
398  return window_frame->getBoundType() == SqlWindowFrameBoundType::EXPR_FOLLOWING ||
399  window_frame->getBoundType() == SqlWindowFrameBoundType::EXPR_PRECEDING;
400  };
401  if (needs_bound_expr_codegen(frame_start_bound)) {
402  frame_start_bound_expr_lvs =
403  code_generator.codegen(frame_start_bound->getBoundExpr(), true, co);
404  frame_start_bound_expr_lv = frame_start_bound_expr_lvs.front();
405  if (frame_start_bound->getBoundExpr()->get_type_info().get_size() != 8) {
406  frame_start_bound_expr_lv =
407  cgen_state_->castToTypeIn(frame_start_bound_expr_lv, 64);
408  }
409  } else {
410  frame_start_bound_expr_lv = cgen_state_->llInt((int64_t)-1);
411  }
412  if (needs_bound_expr_codegen(frame_end_bound)) {
413  frame_end_bound_expr_lvs =
414  code_generator.codegen(frame_end_bound->getBoundExpr(), true, co);
415  frame_end_bound_expr_lv = frame_end_bound_expr_lvs.front();
416  if (frame_end_bound->getBoundExpr()->get_type_info().get_size() != 8) {
417  frame_end_bound_expr_lv = cgen_state_->castToTypeIn(frame_end_bound_expr_lv, 64);
418  }
419  } else {
420  frame_end_bound_expr_lv = cgen_state_->llInt((int64_t)-1);
421  }
422 
423  // for range mode, we need to collect various info regarding ordering column
424  // to determine the frame boundary correctly
425  std::string order_col_type_name{""};
426  if (window_func->getFrameBoundType() ==
428  CHECK(window_func_context->getOrderKeyColumnBuffers().size() == 1);
429  CHECK(window_func->getOrderKeys().size() == 1UL);
430  CHECK(window_func_context->getOrderKeyColumnBuffers().size() == 1UL);
431  order_col_type_name = get_col_type_name_for_framing(
432  window_func_context->getOrderKeyColumnBufferTypes().front().get_type());
433  // ordering column buffer
434  size_t order_key_size =
435  window_func->getOrderKeys().front()->get_type_info().get_size() * 8;
436  const auto order_key_buf_type =
437  llvm::PointerType::get(get_int_type(order_key_size, cgen_state_->context_), 0);
438  const auto order_key_buf = cgen_state_->llInt(reinterpret_cast<int64_t>(
439  window_func_context->getOrderKeyColumnBuffers().front()));
440  order_key_buf_ptr =
441  cgen_state_->ir_builder_.CreateIntToPtr(order_key_buf, order_key_buf_type);
442 
443  // load column value of the current row (of ordering column)
444  const auto rowid_in_partition =
445  code_generator.codegenWindowPosition(window_func_context, current_row_pos);
446  const auto current_col_value_ptr = cgen_state_->ir_builder_.CreateGEP(
447  get_int_type(order_key_size, cgen_state_->context_),
448  order_key_buf_ptr,
449  rowid_in_partition);
450  current_col_value_lv = cgen_state_->ir_builder_.CreateLoad(
451  current_col_value_ptr->getType()->getPointerElementType(),
452  current_col_value_ptr,
453  "current_col_value");
454 
455  // row_id buf of the current partition
456  const auto partition_rowid_buf =
457  cgen_state_->llInt(reinterpret_cast<int64_t>(window_func_context->payload()));
458  const auto partition_rowid_ptr =
459  cgen_state_->ir_builder_.CreateIntToPtr(partition_rowid_buf, pi32_type);
460  target_partition_rowid_ptr =
461  cgen_state_->ir_builder_.CreateGEP(get_int_type(32, cgen_state_->context_),
462  partition_rowid_ptr,
463  current_partition_start_offset_lv);
464 
465  // row_id buf of ordered current partition
466  const auto sorted_partition_buf = cgen_state_->llInt(
467  reinterpret_cast<int64_t>(window_func_context->sortedPartition()));
468  const auto sorted_partition_buf_ptr =
469  cgen_state_->ir_builder_.CreateIntToPtr(sorted_partition_buf, pi64_type);
470  target_partition_sorted_rowid_ptr =
471  cgen_state_->ir_builder_.CreateGEP(get_int_type(64, cgen_state_->context_),
472  sorted_partition_buf_ptr,
473  current_partition_start_offset_lv);
474 
475  // null value of the ordering column
476  order_key_col_null_val_lv = cgen_state_->inlineNull(
477  window_func_context->getOrderKeyColumnBufferTypes().front());
478 
479  // null range of the aggregate tree
480  const auto null_start_pos_buf = cgen_state_->llInt(
481  reinterpret_cast<int64_t>(window_func_context->getNullValueStartPos()));
482  const auto null_start_pos_buf_ptr =
483  cgen_state_->ir_builder_.CreateIntToPtr(null_start_pos_buf, pi64_type);
484  const auto null_start_pos_ptr =
485  cgen_state_->ir_builder_.CreateGEP(get_int_type(64, cgen_state_->context_),
486  null_start_pos_buf_ptr,
487  partition_index_lv);
488  null_start_pos_lv = cgen_state_->ir_builder_.CreateLoad(
489  null_start_pos_ptr->getType()->getPointerElementType(),
490  null_start_pos_ptr,
491  "null_start_pos");
492  const auto null_end_pos_buf = cgen_state_->llInt(
493  reinterpret_cast<int64_t>(window_func_context->getNullValueEndPos()));
494  const auto null_end_pos_buf_ptr =
495  cgen_state_->ir_builder_.CreateIntToPtr(null_end_pos_buf, pi64_type);
496  const auto null_end_pos_ptr =
497  cgen_state_->ir_builder_.CreateGEP(get_int_type(64, cgen_state_->context_),
498  null_end_pos_buf_ptr,
499  partition_index_lv);
500  null_end_pos_lv = cgen_state_->ir_builder_.CreateLoad(
501  null_end_pos_ptr->getType()->getPointerElementType(),
502  null_end_pos_ptr,
503  "null_end_pos");
504  }
505 
506  // compute frame start depending on the bound type
507  if (frame_start_bound->getBoundType() ==
509  // frame starts at the first row of the partition
510  frame_start_bound_lv = cgen_state_->llInt((int64_t)0);
511  } else if (frame_start_bound->getBoundType() ==
513  // frame starts at the position before X rows of the current row
514  CHECK(frame_start_bound_expr_lv);
515  if (window_func->getFrameBoundType() ==
517  frame_start_bound_lv = cgen_state_->emitCall("compute_row_mode_start_index_sub",
518  {current_row_pos,
519  current_partition_start_offset_lv,
520  frame_start_bound_expr_lv});
521  } else {
522  CHECK(window_func->getFrameBoundType() ==
524  std::string lower_bound_func_name{"range_mode_"};
525  lower_bound_func_name.append(order_col_type_name);
526  lower_bound_func_name.append("_sub_frame_lower_bound");
527  frame_start_bound_lv = cgen_state_->emitCall(lower_bound_func_name,
528  {num_elem_current_partition_lv,
529  current_col_value_lv,
530  order_key_buf_ptr,
531  target_partition_rowid_ptr,
532  target_partition_sorted_rowid_ptr,
533  frame_start_bound_expr_lv,
534  order_key_col_null_val_lv,
535  null_start_pos_lv,
536  null_end_pos_lv});
537  }
538  } else if (frame_start_bound->getBoundType() ==
540  // frame start at the current row
541  if (window_func->getFrameBoundType() ==
543  frame_start_bound_lv = cgen_state_->emitCall("compute_row_mode_start_index_sub",
544  {current_row_pos,
545  current_partition_start_offset_lv,
546  cgen_state_->llInt(((int64_t)0))});
547  } else {
548  CHECK(window_func->getFrameBoundType() ==
550  std::string lower_bound_func_name{"compute_"};
551  lower_bound_func_name.append(order_col_type_name);
552  lower_bound_func_name.append("_lower_bound_from_ordered_index");
553  frame_start_bound_lv = cgen_state_->emitCall(lower_bound_func_name,
554  {num_elem_current_partition_lv,
555  current_col_value_lv,
556  order_key_buf_ptr,
557  target_partition_rowid_ptr,
558  target_partition_sorted_rowid_ptr,
559  order_key_col_null_val_lv,
560  null_start_pos_lv,
561  null_end_pos_lv});
562  }
563  } else if (frame_start_bound->getBoundType() ==
565  // frame start at the position after X rows of the current row
566  CHECK(frame_start_bound_expr_lv);
567  if (window_func->getFrameBoundType() ==
569  frame_start_bound_lv = cgen_state_->emitCall("compute_row_mode_start_index_add",
570  {current_row_pos,
571  current_partition_start_offset_lv,
572  frame_start_bound_expr_lv,
573  num_elem_current_partition_lv});
574  } else {
575  CHECK(window_func->getFrameBoundType() ==
577  std::string lower_bound_func_name{"range_mode_"};
578  lower_bound_func_name.append(order_col_type_name);
579  lower_bound_func_name.append("_add_frame_lower_bound");
580  frame_start_bound_lv = cgen_state_->emitCall(lower_bound_func_name,
581  {num_elem_current_partition_lv,
582  current_col_value_lv,
583  order_key_buf_ptr,
584  target_partition_rowid_ptr,
585  target_partition_sorted_rowid_ptr,
586  frame_start_bound_expr_lv,
587  order_key_col_null_val_lv,
588  null_start_pos_lv,
589  null_end_pos_lv});
590  }
591  } else {
592  CHECK(false) << "frame start cannot be UNBOUNDED FOLLOWING";
593  }
594 
595  // compute frame end
596  if (frame_end_bound->getBoundType() == SqlWindowFrameBoundType::UNBOUNDED_PRECEDING) {
597  // frame ends at the first row of the partition
598  CHECK(false) << "frame end cannot be UNBOUNDED PRECEDING";
599  } else if (frame_end_bound->getBoundType() ==
601  // frame ends at the position X rows before the current row
602  CHECK(frame_end_bound_expr_lv);
603  if (window_func->getFrameBoundType() ==
605  frame_end_bound_lv = cgen_state_->emitCall("compute_row_mode_end_index_sub",
606  {current_row_pos,
607  current_partition_start_offset_lv,
608  frame_end_bound_expr_lv});
609  } else {
610  CHECK(window_func->getFrameBoundType() ==
612  std::string upper_bound_func_name{"range_mode_"};
613  upper_bound_func_name.append(order_col_type_name);
614  upper_bound_func_name.append("_sub_frame_upper_bound");
615  frame_end_bound_lv = cgen_state_->emitCall(upper_bound_func_name,
616  {num_elem_current_partition_lv,
617  current_col_value_lv,
618  order_key_buf_ptr,
619  target_partition_rowid_ptr,
620  target_partition_sorted_rowid_ptr,
621  frame_end_bound_expr_lv,
622  order_key_col_null_val_lv,
623  null_start_pos_lv,
624  null_end_pos_lv});
625  }
626  } else if (frame_end_bound->getBoundType() == SqlWindowFrameBoundType::CURRENT_ROW) {
627  // frame ends at the current row
628  if (window_func->getFrameBoundType() ==
630  frame_end_bound_lv = cgen_state_->emitCall("compute_row_mode_end_index_sub",
631  {current_row_pos,
632  current_partition_start_offset_lv,
633  cgen_state_->llInt((int64_t)0)});
634  } else {
635  CHECK(window_func->getFrameBoundType() ==
637  std::string upper_bound_func_name{"compute_"};
638  upper_bound_func_name.append(order_col_type_name);
639  upper_bound_func_name.append("_upper_bound_from_ordered_index");
640  frame_end_bound_lv = cgen_state_->emitCall(upper_bound_func_name,
641  {num_elem_current_partition_lv,
642  current_col_value_lv,
643  order_key_buf_ptr,
644  target_partition_rowid_ptr,
645  target_partition_sorted_rowid_ptr,
646  order_key_col_null_val_lv,
647  null_start_pos_lv,
648  null_end_pos_lv});
649  }
650  } else if (frame_end_bound->getBoundType() ==
652  // frame ends at the position X rows after the current row
653  CHECK(frame_end_bound_expr_lv);
654  if (window_func->getFrameBoundType() ==
656  frame_end_bound_lv = cgen_state_->emitCall("compute_row_mode_end_index_add",
657  {current_row_pos,
658  current_partition_start_offset_lv,
659  frame_end_bound_expr_lv,
660  num_elem_current_partition_lv});
661  } else {
662  CHECK(window_func->getFrameBoundType() ==
664  std::string upper_bound_func_name{"range_mode_"};
665  upper_bound_func_name.append(order_col_type_name);
666  upper_bound_func_name.append("_add_frame_upper_bound");
667  frame_end_bound_lv = cgen_state_->emitCall(upper_bound_func_name,
668  {num_elem_current_partition_lv,
669  current_col_value_lv,
670  order_key_buf_ptr,
671  target_partition_rowid_ptr,
672  target_partition_sorted_rowid_ptr,
673  frame_end_bound_expr_lv,
674  order_key_col_null_val_lv,
675  null_start_pos_lv,
676  null_end_pos_lv});
677  }
678  } else {
679  // frame ends at the last row of the partition
680  CHECK(frame_end_bound->getBoundType() ==
682  frame_end_bound_lv = num_elem_current_partition_lv;
683  }
684 
685  // compute aggregated value over the computed frame range
686  CHECK(frame_start_bound_expr_lv);
687  CHECK(frame_end_bound_expr_lv);
688 
689  // codegen to send a query with frame bound to aggregate tree searcher
690  llvm::Value* aggregation_trees_lv{nullptr};
691  llvm::Value* invalid_val_lv{nullptr};
692  llvm::Value* null_val_lv{nullptr};
693  std::string aggregation_tree_search_func_name{"search_"};
694  std::string aggregation_tree_getter_func_name{"get_"};
695 
696  // prepare null values and aggregate_tree getter and searcher depending on
697  // a type of the ordering column
698  auto agg_expr_ti = args.front()->get_type_info();
699  switch (agg_expr_ti.get_type()) {
700  case SQLTypes::kTINYINT:
701  case SQLTypes::kSMALLINT:
702  case SQLTypes::kINT:
703  case SQLTypes::kBIGINT:
704  case SQLTypes::kNUMERIC:
705  case SQLTypes::kDECIMAL: {
706  if (window_func->getKind() == SqlWindowFunctionKind::MIN) {
707  invalid_val_lv = cgen_state_->llInt(std::numeric_limits<int64_t>::max());
708  } else if (window_func->getKind() == SqlWindowFunctionKind::MAX) {
709  invalid_val_lv = cgen_state_->llInt(std::numeric_limits<int64_t>::lowest());
710  } else {
711  invalid_val_lv = cgen_state_->llInt((int64_t)0);
712  }
713  null_val_lv = cgen_state_->llInt(inline_int_null_value<int64_t>());
714  aggregation_tree_search_func_name += "int64_t";
715  aggregation_tree_getter_func_name += "integer";
716  break;
717  }
718  case SQLTypes::kFLOAT:
719  case SQLTypes::kDOUBLE: {
720  if (window_func->getKind() == SqlWindowFunctionKind::MIN) {
721  invalid_val_lv = cgen_state_->llFp(std::numeric_limits<double>::max());
722  } else if (window_func->getKind() == SqlWindowFunctionKind::MAX) {
723  invalid_val_lv = cgen_state_->llFp(std::numeric_limits<double>::lowest());
724  } else {
725  invalid_val_lv = cgen_state_->llFp((double)0);
726  }
727  null_val_lv = cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE));
728  aggregation_tree_search_func_name += "double";
729  aggregation_tree_getter_func_name += "double";
730  break;
731  }
732  default: {
733  CHECK(false);
734  break;
735  }
736  }
737 
738  // derived aggregation has a different code path
739  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
740  aggregation_tree_search_func_name += "_derived";
741  aggregation_tree_getter_func_name += "_derived";
742  }
743 
744  // get a buffer holding aggregate trees for each partition
745  if (agg_expr_ti.is_integer() || agg_expr_ti.is_decimal()) {
746  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
747  aggregation_trees_lv = cgen_state_->llInt(reinterpret_cast<int64_t>(
748  window_func_context->getDerivedAggregationTreesForIntegerTypeWindowExpr()));
749  } else {
750  aggregation_trees_lv = cgen_state_->llInt(reinterpret_cast<int64_t>(
751  window_func_context->getAggregationTreesForIntegerTypeWindowExpr()));
752  }
753  } else if (agg_expr_ti.is_fp()) {
754  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
755  aggregation_trees_lv = cgen_state_->llInt(reinterpret_cast<int64_t>(
756  window_func_context->getDerivedAggregationTreesForDoubleTypeWindowExpr()));
757  } else {
758  aggregation_trees_lv = cgen_state_->llInt(reinterpret_cast<int64_t>(
759  window_func_context->getAggregationTreesForDoubleTypeWindowExpr()));
760  }
761  }
762 
763  CHECK(aggregation_trees_lv);
764  CHECK(invalid_val_lv);
765  aggregation_tree_search_func_name += "_aggregation_tree";
766  aggregation_tree_getter_func_name += "_aggregation_tree";
767 
768  // get the aggregate tree of the current partition from a window context
769  auto aggregation_trees_ptr =
770  cgen_state_->ir_builder_.CreateIntToPtr(aggregation_trees_lv, ppi64_type);
771  auto target_aggregation_tree_lv = cgen_state_->emitCall(
772  aggregation_tree_getter_func_name, {aggregation_trees_ptr, partition_index_lv});
773 
774  // send a query to the aggregate tree with the frame range:
775  // `frame_start_bound_lv` ~ `frame_end_bound_lv`
776  auto res_lv =
777  cgen_state_->emitCall(aggregation_tree_search_func_name,
778  {target_aggregation_tree_lv,
779  frame_start_bound_lv,
780  frame_end_bound_lv,
781  current_partition_tree_depth_lv,
782  aggregation_tree_fanout_lv,
783  cgen_state_->llBool(agg_expr_ti.is_decimal()),
784  cgen_state_->llInt((int64_t)agg_expr_ti.get_scale()),
785  invalid_val_lv,
786  null_val_lv,
787  agg_type_lv});
788 
789  // handling returned null value if exists
790  std::string null_handler_func_name{"handle_null_val_"};
791  std::vector<llvm::Value*> null_handler_args{res_lv, null_val_lv};
792 
793  // determine null_handling function's name
794  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
795  // average aggregate function returns a value as a double
796  // (and our search* function also returns a double)
797  if (agg_expr_ti.is_fp()) {
798  // fp type: double null value
799  null_handler_func_name += "double_double";
800  } else {
801  // non-fp type: int64_t null type
802  null_handler_func_name += "double_int64_t";
803  }
804  } else if (agg_expr_ti.is_fp()) {
805  // fp type: double null value
806  null_handler_func_name += "double_double";
807  } else {
808  // non-fp type: int64_t null type
809  null_handler_func_name += "int64_t_int64_t";
810  }
811  null_handler_func_name += "_window_framing_agg";
812 
813  // prepare null_val
814  if (window_func->getKind() == SqlWindowFunctionKind::COUNT) {
815  if (agg_expr_ti.is_fp()) {
816  null_handler_args.push_back(cgen_state_->llFp((double)0));
817  } else {
818  null_handler_args.push_back(cgen_state_->llInt((int64_t)0));
819  }
820  } else if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
821  null_handler_args.push_back(cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE)));
822  } else {
823  null_handler_args.push_back(cgen_state_->castToTypeIn(window_func_null_val, 64));
824  }
825  res_lv = cgen_state_->emitCall(null_handler_func_name, null_handler_args);
826 
827  // when AGG_TYPE is double, we get a double type return value, but we expect an
828  // integer type value for the count aggregation
829  if (window_func->getKind() == SqlWindowFunctionKind::COUNT && agg_expr_ti.is_fp()) {
830  return cgen_state_->ir_builder_.CreateFPToSI(
831  res_lv, get_int_type(64, cgen_state_->context_));
832  }
833  return res_lv;
834  } else {
835  llvm::Value* multiplicity_lv = nullptr;
836  if (args.empty()) {
837  cgen_state_->emitCall(agg_name, {aggregate_state, crt_val});
838  } else {
839  cgen_state_->emitCall(agg_name + "_skip_val",
840  {aggregate_state, crt_val, window_func_null_val});
841  }
842  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
843  codegenWindowAvgEpilogue(crt_val, window_func_null_val, multiplicity_lv);
844  }
845  return codegenAggregateWindowState();
846  }
847 }
848 
// Emits the extra "count" update for a window function: a call to one of the
// `agg_count[_float|_double]_skip_val` runtime functions on the window context's
// aggregate-state count.
//
// crt_val               value forwarded as the second argument of the emitted call
// window_func_null_val  value forwarded as the third (skip/null) argument of the call
// multiplicity_lv       not referenced anywhere in the visible body
//
// NOTE(review): the only visible call site invokes this when the window function
// kind is AVG -- confirm there are no other callers.
849 void Executor::codegenWindowAvgEpilogue(llvm::Value* crt_val,
850  llvm::Value* window_func_null_val,
851  llvm::Value* multiplicity_lv) {
852  AUTOMATIC_IR_METADATA(cgen_state_.get());
853  const auto window_func_context =
// NOTE(review): the initializer expression (listing line 854) is missing from this
// extracted listing; presumably it fetches the active window function context for
// this executor (getActiveWindowFunctionContext) -- TODO confirm against the
// original file.
855  const auto window_func = window_func_context->getWindowFunction();
856  const auto window_func_ti = get_adjusted_window_type_info(window_func);
857  const auto pi32_type =
858  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
859  const auto pi64_type =
860  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
// the count slot is addressed through an i32* when the adjusted window type is
// kFLOAT and through an i64* for every other type
861  const auto aggregate_state_type =
862  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
// the value returned by aggregateStateCount() is reinterpreted as an int64_t,
// emitted as an integer constant, and converted back into a pointer in the IR
863  const auto aggregate_state_count_i64 = cgen_state_->llInt(
864  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
865  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
866  aggregate_state_count_i64, aggregate_state_type);
// build the runtime function name: "agg_count" + optional type suffix + "_skip_val"
867  std::string agg_count_func_name = "agg_count";
868  switch (window_func_ti.get_type()) {
869  case kFLOAT: {
870  agg_count_func_name += "_float";
871  break;
872  }
873  case kDOUBLE: {
874  agg_count_func_name += "_double";
875  break;
876  }
877  default: {
// every other type uses the un-suffixed "agg_count" variant
878  break;
879  }
880  }
881  agg_count_func_name += "_skip_val";
882  cgen_state_->emitCall(agg_count_func_name,
883  {aggregate_state_count, crt_val, window_func_null_val});
884 }
885 
// NOTE(review): the signature line of this definition (listing line 886) is missing
// from this extracted listing; presumably this is the body of
// `llvm::Value* Executor::codegenAggregateWindowState()` -- TODO confirm against the
// original file.
//
// Emits IR that reads the current result out of the pointer returned by
// aggregateWindowStatePtr() and returns the resulting llvm::Value*:
//  - AVG:   call to load_avg_float / load_avg_double / load_avg_decimal / load_avg_int
//           with the aggregate state, the aggregate-state count and a double NULL
//  - COUNT: plain load of the aggregate state's pointee type
//  - other: load_float / load_double for kFLOAT / kDOUBLE, plain load otherwise
887  AUTOMATIC_IR_METADATA(cgen_state_.get());
888  const auto pi32_type =
889  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
890  const auto pi64_type =
891  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
892  const auto window_func_context =
// NOTE(review): the initializer expression (listing line 893) is missing from this
// extracted listing; presumably it fetches the active window function context for
// this executor (getActiveWindowFunctionContext) -- TODO confirm.
894  const Analyzer::WindowFunction* window_func = window_func_context->getWindowFunction();
895  const auto window_func_ti = get_adjusted_window_type_info(window_func);
// pointer type for the count slot: i32* when the adjusted type is kFLOAT, i64*
// otherwise (only consumed by the AVG branch below)
896  const auto aggregate_state_type =
897  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
898  auto aggregate_state = aggregateWindowStatePtr();
899  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
// the value returned by aggregateStateCount() is reinterpreted as an int64_t,
// emitted as an integer constant, and converted back into a pointer in the IR
900  const auto aggregate_state_count_i64 = cgen_state_->llInt(
901  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
902  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
903  aggregate_state_count_i64, aggregate_state_type);
// every load_avg_* call receives the inline double NULL as its third argument
904  const auto double_null_lv = cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE));
905  switch (window_func_ti.get_type()) {
906  case kFLOAT: {
907  return cgen_state_->emitCall(
908  "load_avg_float", {aggregate_state, aggregate_state_count, double_null_lv});
909  }
910  case kDOUBLE: {
911  return cgen_state_->emitCall(
912  "load_avg_double", {aggregate_state, aggregate_state_count, double_null_lv});
913  }
914  case kDECIMAL: {
// the decimal variant additionally receives the type's scale as an int32 constant
915  return cgen_state_->emitCall(
916  "load_avg_decimal",
917  {aggregate_state,
918  aggregate_state_count,
919  double_null_lv,
920  cgen_state_->llInt<int32_t>(window_func_ti.get_scale())});
921  }
922  default: {
923  return cgen_state_->emitCall(
924  "load_avg_int", {aggregate_state, aggregate_state_count, double_null_lv});
925  }
926  }
927  }
// COUNT: load the state directly, whatever the adjusted window type is
928  if (window_func->getKind() == SqlWindowFunctionKind::COUNT) {
929  return cgen_state_->ir_builder_.CreateLoad(
930  aggregate_state->getType()->getPointerElementType(), aggregate_state);
931  }
// remaining kinds: choose the load by the adjusted window type
932  switch (window_func_ti.get_type()) {
933  case kFLOAT: {
934  return cgen_state_->emitCall("load_float", {aggregate_state});
935  }
936  case kDOUBLE: {
937  return cgen_state_->emitCall("load_double", {aggregate_state});
938  }
939  default: {
940  return cgen_state_->ir_builder_.CreateLoad(
941  aggregate_state->getType()->getPointerElementType(), aggregate_state);
942  }
943  }
944 }
#define CHECK_EQ(x, y)
Definition: Logger.h:230
SqlWindowFunctionKind getKind() const
Definition: Analyzer.h:2297
SQLTypes
Definition: sqltypes.h:38
#define LOG(tag)
Definition: Logger.h:216
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:515
#define UNREACHABLE()
Definition: Logger.h:266
llvm::Value * aggregateWindowStatePtr()
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:329
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
static WindowFunctionContext * getActiveWindowFunctionContext(Executor *executor)
std::string get_window_agg_name(const SqlWindowFunctionKind kind, const SQLTypeInfo &window_func_ti)
void codegenWindowAvgEpilogue(llvm::Value *crt_val, llvm::Value *window_func_null_val, llvm::Value *multiplicity_lv)
llvm::Value * codegenWindowPosition(const WindowFunctionContext *window_func_context, llvm::Value *pos_arg)
Definition: ColumnIR.cpp:226
static const WindowProjectNodeContext * get(Executor *executor)
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:2299
const WindowFunctionContext * activateWindowFunctionContext(Executor *executor, const size_t target_index) const
llvm::Value * codegenCastBetweenIntTypes(llvm::Value *operand_lv, const SQLTypeInfo &operand_ti, const SQLTypeInfo &ti, bool upscale=true)
Definition: CastIR.cpp:318
#define AUTOMATIC_IR_METADATA(CGENSTATE)
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:81
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:30
void codegenWindowFunctionStateInit(llvm::Value *aggregate_state)
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:343
SqlWindowFunctionKind
Definition: sqldefs.h:110
llvm::Value * codegenAggregateWindowState()
llvm::Value * codegenWindowFunctionAggregate(const CompilationOptions &co)
#define CHECK(condition)
Definition: Logger.h:222
const Analyzer::WindowFunction * getWindowFunction() const
Definition: sqltypes.h:45
llvm::Value * codegenWindowFunctionAggregateCalls(llvm::Value *aggregate_state, const CompilationOptions &co)
llvm::Value * codegenWindowFunction(const size_t target_index, const CompilationOptions &co)
llvm::BasicBlock * codegenWindowResetStateControlFlow()
SQLTypeInfo get_adjusted_window_type_info(const Analyzer::WindowFunction *window_func)