OmniSciDB  72180abbfe
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
WindowFunctionIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "WindowContext.h"
20 
// Emits the LLVM IR value of the window function projected at `target_index`.
//
// The window-function context associated with `target_index` is looked up
// first. Code is then generated according to the window function's kind:
//  - one group of kinds returns a call to the runtime helper
//    `row_number_window_func`; another group calls `percent_window_func`.
//    Both helpers receive the address of the context's output() buffer and
//    the current row position (posArg).
//  - another group generates code for the first argument expression
//    (CHECKed to exist) and returns its single value unchanged;
//  - the remaining handled kinds are delegated to
//    codegenWindowFunctionAggregate(co).
// Any other kind logs FATAL "Invalid window function kind".
//
// NOTE(review): the case labels that select each group are not visible in
// this listing; see SqlWindowFunctionKind for the kind names.
//
// @param target_index index of the target expression holding the window function.
// @param co compilation options forwarded to argument / aggregate codegen.
// @return the generated value. The trailing nullptr is presumably unreachable,
//         since LOG(FATAL) is expected to abort -- confirm.
21 llvm::Value* Executor::codegenWindowFunction(const size_t target_index,
22  const CompilationOptions& co) {
23  CodeGenerator code_generator(this);
24  const auto window_func_context =
26  target_index);
27  const auto window_func = window_func_context->getWindowFunction();
28  switch (window_func->getKind()) {
33  return cgen_state_->emitCall("row_number_window_func",
// The host address of output() is baked into the IR as a 64-bit integer
// constant. NOTE(review): assumes the generated code can dereference that
// address directly -- confirm for every compilation target.
34  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
35  window_func_context->output())),
36  code_generator.posArg(nullptr)});
37  }
40  return cgen_state_->emitCall("percent_window_func",
41  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
42  window_func_context->output())),
43  code_generator.posArg(nullptr)});
44  }
50  const auto& args = window_func->getArgs();
51  CHECK(!args.empty());
52  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
53  CHECK_EQ(arg_lvs.size(), size_t(1));
54  return arg_lvs.front();
55  }
61  return codegenWindowFunctionAggregate(co);
62  }
63  default: {
64  LOG(FATAL) << "Invalid window function kind";
65  }
66  }
67  return nullptr;
68 }
69 
70 namespace {
71 
73  const SQLTypeInfo& window_func_ti) {
74  std::string agg_name;
75  switch (kind) {
77  agg_name = "agg_min";
78  break;
79  }
81  agg_name = "agg_max";
82  break;
83  }
86  agg_name = "agg_sum";
87  break;
88  }
90  agg_name = "agg_count";
91  break;
92  }
93  default: {
94  LOG(FATAL) << "Invalid window function kind";
95  }
96  }
97  switch (window_func_ti.get_type()) {
98  case kFLOAT: {
99  agg_name += "_float";
100  break;
101  }
102  case kDOUBLE: {
103  agg_name += "_double";
104  break;
105  }
106  default: {
107  break;
108  }
109  }
110  return agg_name;
111 }
112 
114  const auto& args = window_func->getArgs();
115  return ((window_func->getKind() == SqlWindowFunctionKind::COUNT && !args.empty()) ||
116  window_func->getKind() == SqlWindowFunctionKind::AVG)
117  ? args.front()->get_type_info()
118  : window_func->get_type_info();
119 }
120 
121 } // namespace
122 
124  const auto window_func_context =
126  const auto window_func = window_func_context->getWindowFunction();
127  const auto arg_ti = get_adjusted_window_type_info(window_func);
128  llvm::Type* aggregate_state_type =
129  arg_ti.get_type() == kFLOAT
130  ? llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0)
131  : llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
132  const auto aggregate_state_i64 = cgen_state_->llInt(
133  reinterpret_cast<const int64_t>(window_func_context->aggregateState()));
134  return cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_i64,
135  aggregate_state_type);
136 }
137 
139  const auto reset_state_false_bb = codegenWindowResetStateControlFlow();
140  auto aggregate_state = aggregateWindowStatePtr();
141  llvm::Value* aggregate_state_count = nullptr;
142  const auto window_func_context =
144  const auto window_func = window_func_context->getWindowFunction();
145  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
146  const auto aggregate_state_count_i64 = cgen_state_->llInt(
147  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
148  const auto pi64_type =
149  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
150  aggregate_state_count =
151  cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_count_i64, pi64_type);
152  }
153  codegenWindowFunctionStateInit(aggregate_state);
154  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
155  const auto count_zero = cgen_state_->llInt(int64_t(0));
156  cgen_state_->emitCall("agg_id", {aggregate_state_count, count_zero});
157  }
158  cgen_state_->ir_builder_.CreateBr(reset_state_false_bb);
159  cgen_state_->ir_builder_.SetInsertPoint(reset_state_false_bb);
161  return codegenWindowFunctionAggregateCalls(aggregate_state, co);
162 }
163 
165  const auto window_func_context =
167  const auto bitset = cgen_state_->llInt(
168  reinterpret_cast<const int64_t>(window_func_context->partitionStart()));
169  const auto min_val = cgen_state_->llInt(int64_t(0));
170  const auto max_val = cgen_state_->llInt(window_func_context->elementCount() - 1);
171  const auto null_val = cgen_state_->llInt(inline_int_null_value<int64_t>());
172  const auto null_bool_val = cgen_state_->llInt<int8_t>(inline_int_null_value<int8_t>());
173  CodeGenerator code_generator(this);
174  const auto reset_state =
175  code_generator.toBool(cgen_state_->emitCall("bit_is_set",
176  {bitset,
177  code_generator.posArg(nullptr),
178  min_val,
179  max_val,
180  null_val,
181  null_bool_val}));
182  const auto reset_state_true_bb = llvm::BasicBlock::Create(
183  cgen_state_->context_, "reset_state.true", cgen_state_->row_func_);
184  const auto reset_state_false_bb = llvm::BasicBlock::Create(
185  cgen_state_->context_, "reset_state.false", cgen_state_->row_func_);
186  cgen_state_->ir_builder_.CreateCondBr(
187  reset_state, reset_state_true_bb, reset_state_false_bb);
188  cgen_state_->ir_builder_.SetInsertPoint(reset_state_true_bb);
189  return reset_state_false_bb;
190 }
191 
// Emits IR that (re)initializes the aggregate state slot of the active window
// function.
//
// In this file it is called from codegenWindowFunctionAggregate(), inside the
// "reset_state.true" block created by codegenWindowResetStateControlFlow().
// That block is entered when bit_is_set() reports the current row's bit in the
// context's partitionStart() bitset as set.
//
// The null sentinel of the adjusted window type is computed as follows:
//  - floating-point types use inlineFpNull;
//  - other types use inlineIntNull, widened to 64 bits.
// The initial value is either zero of the adjusted type or that sentinel,
// depending on the window function kind. It is stored with:
//  - agg_id_double for DOUBLE;
//  - agg_id_float for FLOAT (after bitcasting the pointer to i32*);
//  - agg_id otherwise.
//
// @param aggregate_state pointer to the aggregate state slot to initialize.
192 void Executor::codegenWindowFunctionStateInit(llvm::Value* aggregate_state) {
193  const auto window_func_context =
195  const auto window_func = window_func_context->getWindowFunction();
196  const auto window_func_ti = get_adjusted_window_type_info(window_func);
197  const auto window_func_null_val =
198  window_func_ti.is_fp()
199  ? cgen_state_->inlineFpNull(window_func_ti)
200  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
201  llvm::Value* window_func_init_val;
// Kinds matching the comparison below start from zero of the adjusted type
// (0.0f, 0.0 or int64 0). All other kinds start from the null sentinel above,
// presumably so the later *_skip_val aggregate calls can tell "no value yet".
// NOTE(review): the right-hand operand of this comparison is not visible in
// this listing (presumably SqlWindowFunctionKind::COUNT) -- confirm.
202  if (window_func_context->getWindowFunction()->getKind() ==
204  switch (window_func_ti.get_type()) {
205  case kFLOAT: {
206  window_func_init_val = cgen_state_->llFp(float(0));
207  break;
208  }
209  case kDOUBLE: {
210  window_func_init_val = cgen_state_->llFp(double(0));
211  break;
212  }
213  default: {
214  window_func_init_val = cgen_state_->llInt(int64_t(0));
215  break;
216  }
217  }
218  } else {
219  window_func_init_val = window_func_null_val;
220  }
221  const auto pi32_type =
222  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
223  switch (window_func_ti.get_type()) {
224  case kDOUBLE: {
225  cgen_state_->emitCall("agg_id_double", {aggregate_state, window_func_init_val});
226  break;
227  }
228  case kFLOAT: {
// NOTE(review): when aggregate_state comes from aggregateWindowStatePtr(), it
// is already typed i32* for a FLOAT adjusted type, so this bitcast looks
// redundant but harmless.
229  aggregate_state =
230  cgen_state_->ir_builder_.CreateBitCast(aggregate_state, pi32_type);
231  cgen_state_->emitCall("agg_id_float", {aggregate_state, window_func_init_val});
232  break;
233  }
234  default: {
235  cgen_state_->emitCall("agg_id", {aggregate_state, window_func_init_val});
236  break;
237  }
238  }
239 }
240 
// Emits the per-row update of the running window aggregate. Returns the
// aggregate's value as loaded by codegenAggregateWindowState() after the update.
//
// The value fed to the aggregate (crt_val) is one of:
//  - the constant int64 1 when the function has no arguments. Only COUNT may
//    have none; this is enforced by CHECK.
//  - for SUM over a non-floating-point adjusted type: the first argument, cast
//    from its own type to the adjusted window type (codegenCastBetweenIntTypes
//    with upscale = false).
//  - otherwise: the first argument as-is for FLOAT, or widened to 64 bits.
//
// The runtime function name comes from get_window_agg_name(). Calls that have
// an argument use its "_skip_val" variant and also pass the adjusted type's
// null sentinel. For AVG, codegenWindowAvgEpilogue() additionally updates the
// separate count slot.
//
// @param aggregate_state pointer to the aggregate slot. In this file it is the
//        value produced by aggregateWindowStatePtr().
// @param co compilation options used when generating the argument expression.
// @return the aggregate value loaded after this row's update.
241 llvm::Value* Executor::codegenWindowFunctionAggregateCalls(llvm::Value* aggregate_state,
242  const CompilationOptions& co) {
243  const auto window_func_context =
245  const auto window_func = window_func_context->getWindowFunction();
246  const auto window_func_ti = get_adjusted_window_type_info(window_func);
247  const auto window_func_null_val =
248  window_func_ti.is_fp()
249  ? cgen_state_->inlineFpNull(window_func_ti)
250  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
251  const auto& args = window_func->getArgs();
252  llvm::Value* crt_val;
253  if (args.empty()) {
254  CHECK(window_func->getKind() == SqlWindowFunctionKind::COUNT);
255  crt_val = cgen_state_->llInt(int64_t(1));
256  } else {
257  CodeGenerator code_generator(this);
258  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
259  CHECK_EQ(arg_lvs.size(), size_t(1));
260  if (window_func->getKind() == SqlWindowFunctionKind::SUM && !window_func_ti.is_fp()) {
261  crt_val = code_generator.codegenCastBetweenIntTypes(
262  arg_lvs.front(), args.front()->get_type_info(), window_func_ti, false);
263  } else {
264  crt_val = window_func_ti.get_type() == kFLOAT
265  ? arg_lvs.front()
266  : cgen_state_->castToTypeIn(arg_lvs.front(), 64);
267  }
268  }
269  const auto agg_name = get_window_agg_name(window_func->getKind(), window_func_ti);
// NOTE(review): multiplicity_lv is never assigned a non-null value here. It is
// only forwarded to codegenWindowAvgEpilogue(), which does not appear to read it.
270  llvm::Value* multiplicity_lv = nullptr;
271  if (args.empty()) {
272  cgen_state_->emitCall(agg_name, {aggregate_state, crt_val});
273  } else {
274  cgen_state_->emitCall(agg_name + "_skip_val",
275  {aggregate_state, crt_val, window_func_null_val});
276  }
277  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
278  codegenWindowAvgEpilogue(crt_val, window_func_null_val, multiplicity_lv);
279  }
280  return codegenAggregateWindowState();
281 }
282 
// Emits the count update for AVG. In this file it is only called when the
// window function kind is AVG.
//
// The count slot lives at the context's aggregateStateCount() address. That
// host address is embedded as a 64-bit constant and converted to a pointer.
// The runtime function name is "agg_count" + {"_float" | "_double" | ""} +
// "_skip_val", chosen from the adjusted window type. It is called with
// (count slot, crt_val, window_func_null_val), presumably so that rows whose
// value equals the null sentinel are not counted -- confirm.
//
// @param crt_val the current row's argument value, as prepared by
//        codegenWindowFunctionAggregateCalls().
// @param window_func_null_val null sentinel of the adjusted window type.
// @param multiplicity_lv NOTE(review): not read by this function.
283 void Executor::codegenWindowAvgEpilogue(llvm::Value* crt_val,
284  llvm::Value* window_func_null_val,
285  llvm::Value* multiplicity_lv) {
286  const auto window_func_context =
288  const auto window_func = window_func_context->getWindowFunction();
289  const auto window_func_ti = get_adjusted_window_type_info(window_func);
290  const auto pi32_type =
291  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
292  const auto pi64_type =
293  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
// The count slot is addressed as i32* for FLOAT and i64* otherwise.
// codegenAggregateWindowState() uses the same typing when it passes the slot
// to load_avg_*. NOTE(review): codegenWindowFunctionAggregate() zeroes this
// slot through an i64* with agg_id. Confirm the slot layout makes the 32-bit
// FLOAT access compatible with that 64-bit initialization.
294  const auto aggregate_state_type =
295  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
296  const auto aggregate_state_count_i64 = cgen_state_->llInt(
297  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
298  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
299  aggregate_state_count_i64, aggregate_state_type);
300  std::string agg_count_func_name = "agg_count";
301  switch (window_func_ti.get_type()) {
302  case kFLOAT: {
303  agg_count_func_name += "_float";
304  break;
305  }
306  case kDOUBLE: {
307  agg_count_func_name += "_double";
308  break;
309  }
310  default: {
311  break;
312  }
313  }
314  agg_count_func_name += "_skip_val";
315  cgen_state_->emitCall(agg_count_func_name,
316  {aggregate_state_count, crt_val, window_func_null_val});
317 }
318 
320  const auto pi32_type =
321  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
322  const auto pi64_type =
323  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
324  const auto window_func_context =
326  const Analyzer::WindowFunction* window_func = window_func_context->getWindowFunction();
327  const auto window_func_ti = get_adjusted_window_type_info(window_func);
328  const auto aggregate_state_type =
329  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
330  auto aggregate_state = aggregateWindowStatePtr();
331  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
332  const auto aggregate_state_count_i64 = cgen_state_->llInt(
333  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
334  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
335  aggregate_state_count_i64, aggregate_state_type);
336  const auto double_null_lv = cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE));
337  switch (window_func_ti.get_type()) {
338  case kFLOAT: {
339  return cgen_state_->emitCall(
340  "load_avg_float", {aggregate_state, aggregate_state_count, double_null_lv});
341  }
342  case kDOUBLE: {
343  return cgen_state_->emitCall(
344  "load_avg_double", {aggregate_state, aggregate_state_count, double_null_lv});
345  }
346  case kDECIMAL: {
347  return cgen_state_->emitCall(
348  "load_avg_decimal",
349  {aggregate_state,
350  aggregate_state_count,
351  double_null_lv,
352  cgen_state_->llInt<int32_t>(window_func_ti.get_scale())});
353  }
354  default: {
355  return cgen_state_->emitCall(
356  "load_avg_int", {aggregate_state, aggregate_state_count, double_null_lv});
357  }
358  }
359  }
360  if (window_func->getKind() == SqlWindowFunctionKind::COUNT) {
361  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
362  }
363  switch (window_func_ti.get_type()) {
364  case kFLOAT: {
365  return cgen_state_->emitCall("load_float", {aggregate_state});
366  }
367  case kDOUBLE: {
368  return cgen_state_->emitCall("load_double", {aggregate_state});
369  }
370  default: {
371  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
372  }
373  }
374 }
#define CHECK_EQ(x, y)
Definition: Logger.h:205
SqlWindowFunctionKind getKind() const
Definition: Analyzer.h:1403
#define LOG(tag)
Definition: Logger.h:188
llvm::Value * posArg(const Analyzer::Expr *) const
Definition: ColumnIR.cpp:507
llvm::Value * aggregateWindowStatePtr()
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:257
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
static WindowFunctionContext * getActiveWindowFunctionContext(Executor *executor)
std::string get_window_agg_name(const SqlWindowFunctionKind kind, const SQLTypeInfo &window_func_ti)
void codegenWindowAvgEpilogue(llvm::Value *crt_val, llvm::Value *window_func_null_val, llvm::Value *multiplicity_lv)
static const WindowProjectNodeContext * get(Executor *executor)
CHECK(cgen_state)
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:1405
const WindowFunctionContext * activateWindowFunctionContext(Executor *executor, const size_t target_index) const
llvm::Value * codegenCastBetweenIntTypes(llvm::Value *operand_lv, const SQLTypeInfo &operand_ti, const SQLTypeInfo &ti, bool upscale=true)
Definition: CastIR.cpp:244
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:79
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:26
void codegenWindowFunctionStateInit(llvm::Value *aggregate_state)
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:333
SqlWindowFunctionKind
Definition: sqldefs.h:82
llvm::Value * codegenAggregateWindowState()
llvm::Value * codegenWindowFunctionAggregate(const CompilationOptions &co)
const Analyzer::WindowFunction * getWindowFunction() const
llvm::Value * codegenWindowFunctionAggregateCalls(llvm::Value *aggregate_state, const CompilationOptions &co)
llvm::Value * codegenWindowFunction(const size_t target_index, const CompilationOptions &co)
llvm::BasicBlock * codegenWindowResetStateControlFlow()
SQLTypeInfo get_adjusted_window_type_info(const Analyzer::WindowFunction *window_func)