OmniSciDB  04ee39c94c
WindowFunctionIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "WindowContext.h"
20 
// Emits the LLVM value for the window function attached to `target_index`,
// dispatching on the window function kind.
// NOTE(review): this is a Doxygen listing with gaps in its line numbering
// (e.g. source line 25 and the `case` labels are missing), so the comments
// below describe only the statements that are visible here.
21 llvm::Value* Executor::codegenWindowFunction(const size_t target_index,
22  const CompilationOptions& co) {
23  CodeGenerator code_generator(this);
24  const auto window_func_context =
26  const auto window_func = window_func_context->getWindowFunction();
27  switch (window_func->getKind()) {
// Calls the runtime helper "row_number_window_func" with two arguments: the
// address returned by window_func_context->output(), baked into the IR as an
// i64 constant, and the current row position.
32  return cgen_state_->emitCall("row_number_window_func",
33  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
34  window_func_context->output())),
35  code_generator.posArg(nullptr)});
36  }
// Same calling convention as above, targeting "percent_window_func".
39  return cgen_state_->emitCall("percent_window_func",
40  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
41  window_func_context->output())),
42  code_generator.posArg(nullptr)});
43  }
// Returns the IR for the first argument expression. CHECKs that the function
// has at least one argument and that the argument lowers to exactly one value.
49  const auto& args = window_func->getArgs();
50  CHECK(!args.empty());
51  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
52  CHECK_EQ(arg_lvs.size(), size_t(1));
53  return arg_lvs.front();
54  }
// The kinds handled here (labels not shown) go through the aggregate path.
60  return codegenWindowFunctionAggregate(co);
61  }
62  default: {
63  LOG(FATAL) << "Invalid window function kind";
64  }
65  }
// Every visible case returns, so this is reached only through the default
// branch (i.e. only if LOG(FATAL) returns).
66  return nullptr;
67 }
68 
69 namespace {
70 
72  const SQLTypeInfo& window_func_ti) {
// Builds the name of the runtime aggregate helper for a window aggregate.
// The base name is picked by `kind` ("agg_min", "agg_max", "agg_sum" or
// "agg_count"; the `case` labels are missing from this listing). A "_float"
// or "_double" suffix is appended when `window_func_ti` is FLOAT or DOUBLE.
// Any other kind is reported with LOG(FATAL).
73  std::string agg_name;
74  switch (kind) {
76  agg_name = "agg_min";
77  break;
78  }
80  agg_name = "agg_max";
81  break;
82  }
// NOTE(review): two source lines (83-84) precede this branch, so it likely
// covers more than one kind — confirm against the full file.
85  agg_name = "agg_sum";
86  break;
87  }
89  agg_name = "agg_count";
90  break;
91  }
92  default: {
93  LOG(FATAL) << "Invalid window function kind";
94  }
95  }
// Type suffix. All types other than FLOAT and DOUBLE keep the bare base name.
96  switch (window_func_ti.get_type()) {
97  case kFLOAT: {
98  agg_name += "_float";
99  break;
100  }
101  case kDOUBLE: {
102  agg_name += "_double";
103  break;
104  }
105  default: {
106  break;
107  }
108  }
109  return agg_name;
110 }
111 
// Returns the type that the window aggregate code generation in this file
// works with:
// - the first argument's type for COUNT with an argument, and for AVG;
// - the window function's own result type for everything else.
// NOTE(review): the AVG branch dereferences args.front() without checking
// that `args` is non-empty.
113  const auto& args = window_func->getArgs();
114  return ((window_func->getKind() == SqlWindowFunctionKind::COUNT && !args.empty()) ||
115  window_func->getKind() == SqlWindowFunctionKind::AVG)
116  ? args.front()->get_type_info()
117  : window_func->get_type_info();
118 }
119 
120 } // namespace
121 
// Body of Executor::aggregateWindowStatePtr() (its signature line is missing
// from this listing). It materializes the address returned by
// window_func_context->aggregateState() as an LLVM pointer constant:
// i32* when the adjusted window type is FLOAT, i64* otherwise.
123  const auto window_func_context =
125  const auto window_func = window_func_context->getWindowFunction();
126  const auto arg_ti = get_adjusted_window_type_info(window_func);
127  llvm::Type* aggregate_state_type =
128  arg_ti.get_type() == kFLOAT
129  ? llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0)
130  : llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
// The state address is embedded as an i64 immediate and converted with IntToPtr.
131  const auto aggregate_state_i64 = cgen_state_->llInt(
132  reinterpret_cast<const int64_t>(window_func_context->aggregateState()));
133  return cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_i64,
134  aggregate_state_type);
135 }
136 
// Body of Executor::codegenWindowFunctionAggregate(co) (its signature line is
// missing from this listing). It works in two steps:
//  1. codegenWindowResetStateControlFlow() leaves the builder in a
//     "reset_state.true" block. In that block the aggregate state is
//     reinitialized (for AVG, the separate count state is also zeroed via
//     "agg_id"). The block then jumps to the returned continuation block.
//  2. In the continuation, the per-row aggregate update and the result load
//     are emitted by codegenWindowFunctionAggregateCalls().
138  const auto reset_state_false_bb = codegenWindowResetStateControlFlow();
139  auto aggregate_state = aggregateWindowStatePtr();
140  llvm::Value* aggregate_state_count = nullptr;
141  const auto window_func_context =
143  const auto window_func = window_func_context->getWindowFunction();
// AVG keeps a separate count state. Its address comes from
// aggregateStateCount() and is typed i64* here.
144  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
145  const auto aggregate_state_count_i64 = cgen_state_->llInt(
146  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
147  const auto pi64_type =
148  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
149  aggregate_state_count =
150  cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_count_i64, pi64_type);
151  }
152  codegenWindowFunctionStateInit(aggregate_state);
153  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
154  const auto count_zero = cgen_state_->llInt(int64_t(0));
155  cgen_state_->emitCall("agg_id", {aggregate_state_count, count_zero});
156  }
// Close the reset block by jumping to the continuation, then keep emitting
// code there.
157  cgen_state_->ir_builder_.CreateBr(reset_state_false_bb);
158  cgen_state_->ir_builder_.SetInsertPoint(reset_state_false_bb);
160  return codegenWindowFunctionAggregateCalls(aggregate_state, co);
161 }
162 
// Body of Executor::codegenWindowResetStateControlFlow() (its signature line
// is missing from this listing). It emits a conditional branch on the
// runtime helper "bit_is_set". The helper is queried with the bitset at
// window_func_context->partitionStart() (presumably marking partition
// starts — confirm), the current row position, and the range
// [0, elementCount() - 1].
// On return, the builder's insert point is the new "reset_state.true" block.
// The "reset_state.false" block is returned; the caller must branch to it and
// continue code generation there.
// NOTE(review): if elementCount() can be 0, `elementCount() - 1` may wrap,
// depending on its return type (not visible here) — confirm.
164  const auto window_func_context =
166  const auto bitset = cgen_state_->llInt(
167  reinterpret_cast<const int64_t>(window_func_context->partitionStart()));
168  const auto min_val = cgen_state_->llInt(int64_t(0));
169  const auto max_val = cgen_state_->llInt(window_func_context->elementCount() - 1);
170  const auto null_val = cgen_state_->llInt(inline_int_null_value<int64_t>());
171  const auto null_bool_val = cgen_state_->llInt<int8_t>(inline_int_null_value<int8_t>());
172  CodeGenerator code_generator(this);
173  const auto reset_state =
174  code_generator.toBool(cgen_state_->emitCall("bit_is_set",
175  {bitset,
176  code_generator.posArg(nullptr),
177  min_val,
178  max_val,
179  null_val,
180  null_bool_val}));
181  const auto reset_state_true_bb = llvm::BasicBlock::Create(
182  cgen_state_->context_, "reset_state.true", cgen_state_->row_func_);
183  const auto reset_state_false_bb = llvm::BasicBlock::Create(
184  cgen_state_->context_, "reset_state.false", cgen_state_->row_func_);
185  cgen_state_->ir_builder_.CreateCondBr(
186  reset_state, reset_state_true_bb, reset_state_false_bb);
187  cgen_state_->ir_builder_.SetInsertPoint(reset_state_true_bb);
188  return reset_state_false_bb;
189 }
190 
// Emits the store of the initial value into `aggregate_state`, using a
// runtime "agg_id*" helper chosen by the adjusted window type.
// The null sentinel is an fp null for floating-point types, and an integer
// null widened to 64 bits (castToTypeIn) otherwise.
191 void Executor::codegenWindowFunctionStateInit(llvm::Value* aggregate_state) {
192  const auto window_func_context =
194  const auto window_func = window_func_context->getWindowFunction();
195  const auto window_func_ti = get_adjusted_window_type_info(window_func);
196  const auto window_func_null_val =
197  window_func_ti.is_fp()
198  ? cgen_state_->inlineFpNull(window_func_ti)
199  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
// One kind starts from a zero of the adjusted type; every other kind starts
// from the null sentinel. The kind being compared against (line 202) is
// missing from this listing — presumably COUNT, confirm.
200  llvm::Value* window_func_init_val;
201  if (window_func_context->getWindowFunction()->getKind() ==
203  switch (window_func_ti.get_type()) {
204  case kFLOAT: {
205  window_func_init_val = cgen_state_->llFp(float(0));
206  break;
207  }
208  case kDOUBLE: {
209  window_func_init_val = cgen_state_->llFp(double(0));
210  break;
211  }
212  default: {
213  window_func_init_val = cgen_state_->llInt(int64_t(0));
214  break;
215  }
216  }
217  } else {
218  window_func_init_val = window_func_null_val;
219  }
// FLOAT state is accessed through an i32* (bitcast) before calling
// "agg_id_float". DOUBLE and all other types use the pointer as given.
220  const auto pi32_type =
221  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
222  switch (window_func_ti.get_type()) {
223  case kDOUBLE: {
224  cgen_state_->emitCall("agg_id_double", {aggregate_state, window_func_init_val});
225  break;
226  }
227  case kFLOAT: {
228  aggregate_state =
229  cgen_state_->ir_builder_.CreateBitCast(aggregate_state, pi32_type);
230  cgen_state_->emitCall("agg_id_float", {aggregate_state, window_func_init_val});
231  break;
232  }
233  default: {
234  cgen_state_->emitCall("agg_id", {aggregate_state, window_func_init_val});
235  break;
236  }
237  }
238 }
239 
// Emits the per-row aggregate update into `aggregate_state`. For AVG it also
// updates the count state. It returns the value produced by
// codegenAggregateWindowState().
240 llvm::Value* Executor::codegenWindowFunctionAggregateCalls(llvm::Value* aggregate_state,
241  const CompilationOptions& co) {
242  const auto window_func_context =
244  const auto window_func = window_func_context->getWindowFunction();
245  const auto window_func_ti = get_adjusted_window_type_info(window_func);
246  const auto window_func_null_val =
247  window_func_ti.is_fp()
248  ? cgen_state_->inlineFpNull(window_func_ti)
249  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
250  const auto& args = window_func->getArgs();
251  llvm::Value* crt_val;
// With no arguments only COUNT is allowed (CHECKed); each row contributes
// the constant 1.
252  if (args.empty()) {
253  CHECK(window_func->getKind() == SqlWindowFunctionKind::COUNT);
254  crt_val = cgen_state_->llInt(int64_t(1));
255  } else {
// Otherwise the row value is the IR of the first argument:
// - SUM over a non-floating-point adjusted type casts it between integer
//   types (upscale=false);
// - in all other cases, FLOAT is passed through unchanged and other types go
//   through castToTypeIn(..., 64).
256  CodeGenerator code_generator(this);
257  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
258  CHECK_EQ(arg_lvs.size(), size_t(1));
259  if (window_func->getKind() == SqlWindowFunctionKind::SUM && !window_func_ti.is_fp()) {
260  crt_val = code_generator.codegenCastBetweenIntTypes(
261  arg_lvs.front(), args.front()->get_type_info(), window_func_ti, false);
262  } else {
263  crt_val = window_func_ti.get_type() == kFLOAT
264  ? arg_lvs.front()
265  : cgen_state_->castToTypeIn(arg_lvs.front(), 64);
266  }
267  }
268  const auto agg_name = get_window_agg_name(window_func->getKind(), window_func_ti);
// NOTE(review): multiplicity_lv is never assigned, so the AVG epilogue below
// always receives nullptr.
269  llvm::Value* multiplicity_lv = nullptr;
// Calls with an argument use the "_skip_val" variant of the helper, which
// also receives the null sentinel.
270  if (args.empty()) {
271  cgen_state_->emitCall(agg_name, {aggregate_state, crt_val});
272  } else {
273  cgen_state_->emitCall(agg_name + "_skip_val",
274  {aggregate_state, crt_val, window_func_null_val});
275  }
276  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
277  codegenWindowAvgEpilogue(crt_val, window_func_null_val, multiplicity_lv);
278  }
279  return codegenAggregateWindowState();
280 }
281 
// Emits the AVG count update. It calls
// "agg_count[_float|_double]_skip_val" on the state at
// window_func_context->aggregateStateCount(), passing the current row value
// `crt_val` and the null sentinel `window_func_null_val`.
// `multiplicity_lv` is accepted but not used by this function.
// NOTE(review): here the count pointer is typed i32* when the adjusted type is
// FLOAT. codegenWindowFunctionAggregate zeroes the same count through an i64*
// via "agg_id". Confirm that the mixed widths are intended.
282 void Executor::codegenWindowAvgEpilogue(llvm::Value* crt_val,
283  llvm::Value* window_func_null_val,
284  llvm::Value* multiplicity_lv) {
285  const auto window_func_context =
287  const auto window_func = window_func_context->getWindowFunction();
288  const auto window_func_ti = get_adjusted_window_type_info(window_func);
289  const auto pi32_type =
290  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
291  const auto pi64_type =
292  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
293  const auto aggregate_state_type =
294  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
295  const auto aggregate_state_count_i64 = cgen_state_->llInt(
296  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
297  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
298  aggregate_state_count_i64, aggregate_state_type);
// Helper name: "agg_count", plus a type suffix for FLOAT/DOUBLE, always
// followed by "_skip_val".
299  std::string agg_count_func_name = "agg_count";
300  switch (window_func_ti.get_type()) {
301  case kFLOAT: {
302  agg_count_func_name += "_float";
303  break;
304  }
305  case kDOUBLE: {
306  agg_count_func_name += "_double";
307  break;
308  }
309  default: {
310  break;
311  }
312  }
313  agg_count_func_name += "_skip_val";
314  cgen_state_->emitCall(agg_count_func_name,
315  {aggregate_state_count, crt_val, window_func_null_val});
316 }
317 
// Body of Executor::codegenAggregateWindowState() (its signature line is
// missing from this listing). It emits the load of the current aggregate
// result:
// - AVG: calls load_avg_{float,double,decimal,int} with the sum state, the
//   count state and a DOUBLE null sentinel; the decimal variant also
//   receives the scale.
// - COUNT: plain load of the state.
// - otherwise: "load_float" / "load_double" helpers, or a plain load.
319  const auto pi32_type =
320  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
321  const auto pi64_type =
322  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
323  const auto window_func_context =
325  const Analyzer::WindowFunction* window_func = window_func_context->getWindowFunction();
326  const auto window_func_ti = get_adjusted_window_type_info(window_func);
// Only used for the AVG count pointer below.
327  const auto aggregate_state_type =
328  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
329  auto aggregate_state = aggregateWindowStatePtr();
330  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
331  const auto aggregate_state_count_i64 = cgen_state_->llInt(
332  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
333  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
334  aggregate_state_count_i64, aggregate_state_type);
335  const auto double_null_lv = cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE));
336  switch (window_func_ti.get_type()) {
337  case kFLOAT: {
338  return cgen_state_->emitCall(
339  "load_avg_float", {aggregate_state, aggregate_state_count, double_null_lv});
340  }
341  case kDOUBLE: {
342  return cgen_state_->emitCall(
343  "load_avg_double", {aggregate_state, aggregate_state_count, double_null_lv});
344  }
345  case kDECIMAL: {
346  return cgen_state_->emitCall(
347  "load_avg_decimal",
348  {aggregate_state,
349  aggregate_state_count,
350  double_null_lv,
351  cgen_state_->llInt<int32_t>(window_func_ti.get_scale())});
352  }
353  default: {
354  return cgen_state_->emitCall(
355  "load_avg_int", {aggregate_state, aggregate_state_count, double_null_lv});
356  }
357  }
358  }
// NOTE(review): for COUNT over a FLOAT argument the adjusted type is FLOAT, so
// aggregateWindowStatePtr() yields an i32* and this load produces an i32.
// Confirm that this matches the expected result type.
// NOTE(review): CreateLoad without an explicit element type is the legacy
// IRBuilder overload; newer LLVM releases require the typed overload.
359  if (window_func->getKind() == SqlWindowFunctionKind::COUNT) {
360  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
361  }
362  switch (window_func_ti.get_type()) {
363  case kFLOAT: {
364  return cgen_state_->emitCall("load_float", {aggregate_state});
365  }
366  case kDOUBLE: {
367  return cgen_state_->emitCall("load_double", {aggregate_state});
368  }
369  default: {
370  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
371  }
372  }
373 }
#define CHECK_EQ(x, y)
Definition: Logger.h:195
#define LOG(tag)
Definition: Logger.h:182
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:323
llvm::Value * aggregateWindowStatePtr()
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
std::string get_window_agg_name(const SqlWindowFunctionKind kind, const SQLTypeInfo &window_func_ti)
void codegenWindowAvgEpilogue(llvm::Value *crt_val, llvm::Value *window_func_null_val, llvm::Value *multiplicity_lv)
const Analyzer::WindowFunction * getWindowFunction() const
static const WindowProjectNodeContext * get()
llvm::Value * codegenCastBetweenIntTypes(llvm::Value *operand_lv, const SQLTypeInfo &operand_ti, const SQLTypeInfo &ti, bool upscale=true)
Definition: CastIR.cpp:228
SQLTypeInfoCore< ArrayContextTypeSizer, ExecutorTypePackaging, DateTimeFacilities > SQLTypeInfo
Definition: sqltypes.h:823
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:25
void codegenWindowFunctionStateInit(llvm::Value *aggregate_state)
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:333
SqlWindowFunctionKind
Definition: sqldefs.h:73
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:1341
llvm::Value * codegenAggregateWindowState()
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:77
llvm::Value * codegenWindowFunctionAggregate(const CompilationOptions &co)
const WindowFunctionContext * activateWindowFunctionContext(const size_t target_index) const
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:503
#define CHECK(condition)
Definition: Logger.h:187
llvm::Value * codegenWindowFunctionAggregateCalls(llvm::Value *aggregate_state, const CompilationOptions &co)
SqlWindowFunctionKind getKind() const
Definition: Analyzer.h:1339
llvm::Value * codegenWindowFunction(const size_t target_index, const CompilationOptions &co)
static WindowFunctionContext * getActiveWindowFunctionContext()
llvm::BasicBlock * codegenWindowResetStateControlFlow()
SQLTypeInfo get_adjusted_window_type_info(const Analyzer::WindowFunction *window_func)