OmniSciDB  dfae7c3b14
WindowFunctionIR.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 OmniSci, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "CodeGenerator.h"
18 #include "Execute.h"
19 #include "WindowContext.h"
20 
// Generates the IR value for the window function projected at `target_index`,
// dispatching on the window function kind.
// NOTE(review): the `case` labels of the switch below (listing lines 30-33,
// 39-40, 46-50, 57-61) are missing from this extract, so the per-branch kind
// names given in the comments below are presumptions — confirm against the
// real file.
// @param target_index  index of the projected target whose window function
//                      context is activated.
// @param co            compilation options forwarded to argument codegen.
// @return              the IR value for the current row. The trailing
//                      `return nullptr` follows the LOG(FATAL) default branch
//                      (presumably unreachable if LOG(FATAL) aborts — confirm
//                      in Logger.h).
21 llvm::Value* Executor::codegenWindowFunction(const size_t target_index,
22  const CompilationOptions& co) {
23  AUTOMATIC_IR_METADATA(cgen_state_.get());
24  CodeGenerator code_generator(this);
25  const auto window_func_context =
// NOTE(review): the first half of this initializer (listing line 26) is missing;
// per the symbol index it is presumably
// WindowProjectNodeContext::get(this)->activateWindowFunctionContext(this, ...).
27  target_index);
28  const auto window_func = window_func_context->getWindowFunction();
29  switch (window_func->getKind()) {
// Presumably the ranking kinds (labels missing). This branch passes the address
// of the context's output() buffer, embedded as a 64-bit integer constant, and
// the current row position to the runtime helper row_number_window_func.
34  return cgen_state_->emitCall("row_number_window_func",
35  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
36  window_func_context->output())),
37  code_generator.posArg(nullptr)});
38  }
// Same shape as above, but using the runtime helper percent_window_func
// (labels missing; presumably the percentage-style ranking kinds).
41  return cgen_state_->emitCall("percent_window_func",
42  {cgen_state_->llInt(reinterpret_cast<const int64_t>(
43  window_func_context->output())),
44  code_generator.posArg(nullptr)});
45  }
// Value-returning kinds (labels missing; presumably LAG/LEAD/FIRST_VALUE/
// LAST_VALUE): generate the first argument for the current row and return its
// single IR value.
51  const auto& args = window_func->getArgs();
52  CHECK(!args.empty());
53  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
54  CHECK_EQ(arg_lvs.size(), size_t(1));
55  return arg_lvs.front();
56  }
// Aggregate kinds (labels missing; presumably AVG/MIN/MAX/SUM/COUNT) are handled
// by codegenWindowFunctionAggregate.
62  return codegenWindowFunctionAggregate(co);
63  }
64  default: {
65  LOG(FATAL) << "Invalid window function kind";
66  }
67  }
68  return nullptr;
69 }
70 
71 namespace {
72 
// get_window_agg_name: returns the name of the runtime aggregate helper for a
// window aggregate. The base name is chosen by `kind`, then "_float" or
// "_double" is appended when window_func_ti is kFLOAT or kDOUBLE; other types
// get no suffix. The default branch logs FATAL.
// NOTE(review): the first signature line (listing line 73) and the case labels
// (listing lines 77, 81, 85-86, 90) are missing from this extract. Per the
// symbol index the signature is
//   std::string get_window_agg_name(const SqlWindowFunctionKind kind,
//                                   const SQLTypeInfo& window_func_ti)
// and the labels presumably map MIN -> agg_min, MAX -> agg_max,
// (two labels, likely AVG and SUM) -> agg_sum, COUNT -> agg_count. Confirm.
74  const SQLTypeInfo& window_func_ti) {
75  std::string agg_name;
76  switch (kind) {
78  agg_name = "agg_min";
79  break;
80  }
82  agg_name = "agg_max";
83  break;
84  }
87  agg_name = "agg_sum";
88  break;
89  }
91  agg_name = "agg_count";
92  break;
93  }
94  default: {
95  LOG(FATAL) << "Invalid window function kind";
96  }
97  }
98  switch (window_func_ti.get_type()) {
99  case kFLOAT: {
100  agg_name += "_float";
101  break;
102  }
103  case kDOUBLE: {
104  agg_name += "_double";
105  break;
106  }
107  default: {
108  break;
109  }
110  }
111  return agg_name;
112 }
113 
// get_adjusted_window_type_info: returns the type info that the window
// aggregate codegen in this file works with. It is the first argument's type
// for COUNT with an argument and for AVG; otherwise it is the window function's
// own result type.
// NOTE(review): the signature line (listing line 114) is missing; per the
// symbol index it is
//   SQLTypeInfo get_adjusted_window_type_info(const Analyzer::WindowFunction* window_func)
// NOTE(review): for AVG this calls args.front() without checking that args is
// non-empty; presumably AVG always carries exactly one argument. Confirm.
115  const auto& args = window_func->getArgs();
116  return ((window_func->getKind() == SqlWindowFunctionKind::COUNT && !args.empty()) ||
117  window_func->getKind() == SqlWindowFunctionKind::AVG)
118  ? args.front()->get_type_info()
119  : window_func->get_type_info();
120 }
121 
122 } // namespace
123 
// NOTE(review): the signature line (listing line 124) is missing from this
// extract. Per its callers and the symbol index, this is the body of
//   llvm::Value* Executor::aggregateWindowStatePtr()
// It returns an IR pointer to the active window function context's aggregate
// state. The aggregateState() address is embedded as a 64-bit integer constant
// and converted with IntToPtr: to i32* when the adjusted window type is kFLOAT,
// otherwise to i64*.
125  AUTOMATIC_IR_METADATA(cgen_state_.get());
126  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 127) missing; per the
// symbol index presumably
// WindowProjectNodeContext::get(this)->getActiveWindowFunctionContext(this).
128  const auto window_func = window_func_context->getWindowFunction();
129  const auto arg_ti = get_adjusted_window_type_info(window_func);
130  llvm::Type* aggregate_state_type =
131  arg_ti.get_type() == kFLOAT
132  ? llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0)
133  : llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
134  const auto aggregate_state_i64 = cgen_state_->llInt(
135  reinterpret_cast<const int64_t>(window_func_context->aggregateState()));
136  return cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_i64,
137  aggregate_state_type);
138 }
139 
// NOTE(review): the signature line (listing line 140) is missing from this
// extract. Per the symbol index, this is the body of
//   llvm::Value* Executor::codegenWindowFunctionAggregate(const CompilationOptions& co)
// It emits a window aggregate in two parts:
// - When the reset-state branch from codegenWindowResetStateControlFlow() is
//   taken, the aggregate state (and, for AVG, the count state) is
//   re-initialized.
// - Control then continues in the returned "reset_state.false" block, where
//   the per-row aggregate update is emitted.
// It returns the value produced by codegenWindowFunctionAggregateCalls().
141  AUTOMATIC_IR_METADATA(cgen_state_.get());
// After this call the builder's insert point is inside "reset_state.true".
142  const auto reset_state_false_bb = codegenWindowResetStateControlFlow();
// NOTE(review): aggregate_state is created while inserting into
// reset_state.true, yet it is used after the join in reset_state.false, which
// reset_state.true does not dominate. This is only valid because IntToPtr of a
// constant integer is presumably folded into a constant expression by
// IRBuilder's default folder, so no instruction is emitted. Keep it that way.
143  auto aggregate_state = aggregateWindowStatePtr();
144  llvm::Value* aggregate_state_count = nullptr;
145  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 146) is missing.
147  const auto window_func = window_func_context->getWindowFunction();
148  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
149  const auto aggregate_state_count_i64 = cgen_state_->llInt(
150  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
151  const auto pi64_type =
152  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
153  aggregate_state_count =
154  cgen_state_->ir_builder_.CreateIntToPtr(aggregate_state_count_i64, pi64_type);
155  }
156  codegenWindowFunctionStateInit(aggregate_state);
157  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
// AVG: reset the count state to zero with a 64-bit agg_id store.
158  const auto count_zero = cgen_state_->llInt(int64_t(0));
159  cgen_state_->emitCall("agg_id", {aggregate_state_count, count_zero});
160  }
// Leave the reset block and continue emitting in reset_state.false.
161  cgen_state_->ir_builder_.CreateBr(reset_state_false_bb);
162  cgen_state_->ir_builder_.SetInsertPoint(reset_state_false_bb);
// NOTE(review): listing line 163 is missing from this extract.
164  return codegenWindowFunctionAggregateCalls(aggregate_state, co);
165 }
166 
// NOTE(review): the signature line (listing line 167) is missing from this
// extract. Per the symbol index, this is the body of
//   llvm::BasicBlock* Executor::codegenWindowResetStateControlFlow()
// It emits a conditional branch on the result of the runtime helper bit_is_set.
// bit_is_set is called with:
// - the active context's partitionStart() bitset address,
// - the current row position,
// - the range [0, elementCount() - 1],
// - int64 and int8 null sentinels.
// The result is converted with toBool. The function creates the
// "reset_state.true" and "reset_state.false" blocks, leaves the builder
// positioned in "reset_state.true", and returns "reset_state.false". The caller
// must emit the reset code and then branch to the returned block.
168  AUTOMATIC_IR_METADATA(cgen_state_.get());
169  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 170) is missing.
171  const auto bitset = cgen_state_->llInt(
172  reinterpret_cast<const int64_t>(window_func_context->partitionStart()));
173  const auto min_val = cgen_state_->llInt(int64_t(0));
// NOTE(review): if elementCount() returns an unsigned type, a value of 0 makes
// this subtraction wrap. Confirm that elementCount() > 0 whenever this is
// reached.
174  const auto max_val = cgen_state_->llInt(window_func_context->elementCount() - 1);
175  const auto null_val = cgen_state_->llInt(inline_int_null_value<int64_t>());
176  const auto null_bool_val = cgen_state_->llInt<int8_t>(inline_int_null_value<int8_t>());
177  CodeGenerator code_generator(this);
178  const auto reset_state =
179  code_generator.toBool(cgen_state_->emitCall("bit_is_set",
180  {bitset,
181  code_generator.posArg(nullptr),
182  min_val,
183  max_val,
184  null_val,
185  null_bool_val}));
186  const auto reset_state_true_bb = llvm::BasicBlock::Create(
187  cgen_state_->context_, "reset_state.true", cgen_state_->current_func_);
188  const auto reset_state_false_bb = llvm::BasicBlock::Create(
189  cgen_state_->context_, "reset_state.false", cgen_state_->current_func_);
190  cgen_state_->ir_builder_.CreateCondBr(
191  reset_state, reset_state_true_bb, reset_state_false_bb);
192  cgen_state_->ir_builder_.SetInsertPoint(reset_state_true_bb);
193  return reset_state_false_bb;
194 }
195 
// Initializes the aggregate state slot that `aggregate_state` points to, for the
// active window function.
// - For one aggregate kind the initial value is a zero of the adjusted window
//   type. The kind name on listing line 208 is missing from this extract;
//   presumably it is COUNT.
// - For all other kinds the initial value is the type's null sentinel.
// The value is stored with agg_id_double, agg_id_float or agg_id, according to
// the adjusted window type.
// @param aggregate_state  IR pointer to the state slot. For kFLOAT it is
//                         bitcast to i32* before agg_id_float is called.
196 void Executor::codegenWindowFunctionStateInit(llvm::Value* aggregate_state) {
197  AUTOMATIC_IR_METADATA(cgen_state_.get());
198  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 199) is missing.
200  const auto window_func = window_func_context->getWindowFunction();
201  const auto window_func_ti = get_adjusted_window_type_info(window_func);
202  const auto window_func_null_val =
203  window_func_ti.is_fp()
204  ? cgen_state_->inlineFpNull(window_func_ti)
205  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
206  llvm::Value* window_func_init_val;
207  if (window_func_context->getWindowFunction()->getKind() ==
// NOTE(review): the right-hand side of this comparison (listing line 208) is
// missing; presumably `SqlWindowFunctionKind::COUNT) {`.
209  switch (window_func_ti.get_type()) {
210  case kFLOAT: {
211  window_func_init_val = cgen_state_->llFp(float(0));
212  break;
213  }
214  case kDOUBLE: {
215  window_func_init_val = cgen_state_->llFp(double(0));
216  break;
217  }
218  default: {
219  window_func_init_val = cgen_state_->llInt(int64_t(0));
220  break;
221  }
222  }
223  } else {
224  window_func_init_val = window_func_null_val;
225  }
226  const auto pi32_type =
227  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
228  switch (window_func_ti.get_type()) {
229  case kDOUBLE: {
230  cgen_state_->emitCall("agg_id_double", {aggregate_state, window_func_init_val});
231  break;
232  }
233  case kFLOAT: {
// NOTE(review): when the caller passes aggregateWindowStatePtr() (as
// codegenWindowFunctionAggregate does), this pointer is already i32* for
// kFLOAT, so the bitcast looks like a no-op. Confirm before removing it, since
// other callers may pass i64*.
234  aggregate_state =
235  cgen_state_->ir_builder_.CreateBitCast(aggregate_state, pi32_type);
236  cgen_state_->emitCall("agg_id_float", {aggregate_state, window_func_init_val});
237  break;
238  }
239  default: {
240  cgen_state_->emitCall("agg_id", {aggregate_state, window_func_init_val});
241  break;
242  }
243  }
244 }
245 
// Emits the per-row update of the window aggregate state and returns the
// current aggregate value, obtained from codegenAggregateWindowState().
// - No arguments: only COUNT is allowed (CHECKed). The update value is 1, and
//   the plain helper named by get_window_agg_name is called.
// - With an argument: the first argument is generated.
//   - For SUM over a non-fp adjusted type, it is converted with
//     codegenCastBetweenIntTypes (upscale = false).
//   - For kFLOAT it is used as is.
//   - Otherwise it is passed through castToTypeIn(..., 64).
//   The "<helper>_skip_val" variant is then called, with the type's null
//   sentinel as the skip value.
// - AVG additionally updates the count state via codegenWindowAvgEpilogue().
// @param aggregate_state  IR pointer to the aggregate state slot.
// @param co               compilation options forwarded to argument codegen.
246 llvm::Value* Executor::codegenWindowFunctionAggregateCalls(llvm::Value* aggregate_state,
247  const CompilationOptions& co) {
248  AUTOMATIC_IR_METADATA(cgen_state_.get());
249  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 250) is missing.
251  const auto window_func = window_func_context->getWindowFunction();
252  const auto window_func_ti = get_adjusted_window_type_info(window_func);
253  const auto window_func_null_val =
254  window_func_ti.is_fp()
255  ? cgen_state_->inlineFpNull(window_func_ti)
256  : cgen_state_->castToTypeIn(cgen_state_->inlineIntNull(window_func_ti), 64);
257  const auto& args = window_func->getArgs();
258  llvm::Value* crt_val;
259  if (args.empty()) {
260  CHECK(window_func->getKind() == SqlWindowFunctionKind::COUNT);
261  crt_val = cgen_state_->llInt(int64_t(1));
262  } else {
263  CodeGenerator code_generator(this);
264  const auto arg_lvs = code_generator.codegen(args.front().get(), true, co);
265  CHECK_EQ(arg_lvs.size(), size_t(1));
266  if (window_func->getKind() == SqlWindowFunctionKind::SUM && !window_func_ti.is_fp()) {
267  crt_val = code_generator.codegenCastBetweenIntTypes(
268  arg_lvs.front(), args.front()->get_type_info(), window_func_ti, false);
269  } else {
270  crt_val = window_func_ti.get_type() == kFLOAT
271  ? arg_lvs.front()
272  : cgen_state_->castToTypeIn(arg_lvs.front(), 64);
273  }
274  }
275  const auto agg_name = get_window_agg_name(window_func->getKind(), window_func_ti);
// NOTE(review): multiplicity_lv is never assigned, so it is always nullptr when
// passed to codegenWindowAvgEpilogue (which does not read it).
276  llvm::Value* multiplicity_lv = nullptr;
277  if (args.empty()) {
278  cgen_state_->emitCall(agg_name, {aggregate_state, crt_val});
279  } else {
280  cgen_state_->emitCall(agg_name + "_skip_val",
281  {aggregate_state, crt_val, window_func_null_val});
282  }
283  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
284  codegenWindowAvgEpilogue(crt_val, window_func_null_val, multiplicity_lv);
285  }
286  return codegenAggregateWindowState();
287 }
288 
// Updates the AVG count state for the current row.
// It calls agg_count[_float|_double]_skip_val on the context's
// aggregateStateCount() slot, passing `crt_val` as the value and
// `window_func_null_val` as the skip value. The slot is typed i32* when the
// adjusted window type is kFLOAT, otherwise i64*.
// @param crt_val               per-row value the caller also passed to the main
//                              aggregate call.
// @param window_func_null_val  null sentinel used as the skip value.
// @param multiplicity_lv       not read by this function.
// NOTE(review): for kFLOAT the count slot is updated through an i32* here, while
// codegenWindowFunctionAggregate zeroes it with a 64-bit agg_id store. These
// are consistent only if the slot is at least 8 bytes. Confirm.
289 void Executor::codegenWindowAvgEpilogue(llvm::Value* crt_val,
290  llvm::Value* window_func_null_val,
291  llvm::Value* multiplicity_lv) {
292  AUTOMATIC_IR_METADATA(cgen_state_.get());
293  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 294) is missing.
295  const auto window_func = window_func_context->getWindowFunction();
296  const auto window_func_ti = get_adjusted_window_type_info(window_func);
297  const auto pi32_type =
298  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
299  const auto pi64_type =
300  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
301  const auto aggregate_state_type =
302  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
303  const auto aggregate_state_count_i64 = cgen_state_->llInt(
304  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
305  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
306  aggregate_state_count_i64, aggregate_state_type);
// Helper name: "agg_count" + optional "_float"/"_double" + "_skip_val".
307  std::string agg_count_func_name = "agg_count";
308  switch (window_func_ti.get_type()) {
309  case kFLOAT: {
310  agg_count_func_name += "_float";
311  break;
312  }
313  case kDOUBLE: {
314  agg_count_func_name += "_double";
315  break;
316  }
317  default: {
318  break;
319  }
320  }
321  agg_count_func_name += "_skip_val";
322  cgen_state_->emitCall(agg_count_func_name,
323  {aggregate_state_count, crt_val, window_func_null_val});
324 }
325 
// NOTE(review): the signature line (listing line 326) is missing from this
// extract. Per the symbol index, this is the body of
//   llvm::Value* Executor::codegenAggregateWindowState()
// It returns the current value of the active window aggregate:
// - AVG: calls load_avg_float, load_avg_double, load_avg_decimal (which also
//   receives the type's scale as an int32) or load_avg_int. Each receives the
//   sum-state pointer, the count-state pointer (i32* for kFLOAT, else i64*) and
//   a double null sentinel.
// - COUNT: a plain load of the state pointer.
// - Otherwise: load_float for kFLOAT, load_double for kDOUBLE, and a plain
//   load for every other type.
327  AUTOMATIC_IR_METADATA(cgen_state_.get());
328  const auto pi32_type =
329  llvm::PointerType::get(get_int_type(32, cgen_state_->context_), 0);
330  const auto pi64_type =
331  llvm::PointerType::get(get_int_type(64, cgen_state_->context_), 0);
332  const auto window_func_context =
// NOTE(review): initializer continuation (listing line 333) is missing.
334  const Analyzer::WindowFunction* window_func = window_func_context->getWindowFunction();
335  const auto window_func_ti = get_adjusted_window_type_info(window_func);
336  const auto aggregate_state_type =
337  window_func_ti.get_type() == kFLOAT ? pi32_type : pi64_type;
338  auto aggregate_state = aggregateWindowStatePtr();
339  if (window_func->getKind() == SqlWindowFunctionKind::AVG) {
340  const auto aggregate_state_count_i64 = cgen_state_->llInt(
341  reinterpret_cast<const int64_t>(window_func_context->aggregateStateCount()));
342  auto aggregate_state_count = cgen_state_->ir_builder_.CreateIntToPtr(
343  aggregate_state_count_i64, aggregate_state_type);
344  const auto double_null_lv = cgen_state_->inlineFpNull(SQLTypeInfo(kDOUBLE));
345  switch (window_func_ti.get_type()) {
346  case kFLOAT: {
347  return cgen_state_->emitCall(
348  "load_avg_float", {aggregate_state, aggregate_state_count, double_null_lv});
349  }
350  case kDOUBLE: {
351  return cgen_state_->emitCall(
352  "load_avg_double", {aggregate_state, aggregate_state_count, double_null_lv});
353  }
354  case kDECIMAL: {
355  return cgen_state_->emitCall(
356  "load_avg_decimal",
357  {aggregate_state,
358  aggregate_state_count,
359  double_null_lv,
360  cgen_state_->llInt<int32_t>(window_func_ti.get_scale())});
361  }
362  default: {
363  return cgen_state_->emitCall(
364  "load_avg_int", {aggregate_state, aggregate_state_count, double_null_lv});
365  }
366  }
367  }
368  if (window_func->getKind() == SqlWindowFunctionKind::COUNT) {
// NOTE(review): for COUNT with a kFLOAT argument, the adjusted type is kFLOAT.
// aggregateWindowStatePtr() then yields an i32*, so this load produces a 32-bit
// value. Verify that this matches the expected COUNT result width.
369  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
370  }
371  switch (window_func_ti.get_type()) {
372  case kFLOAT: {
373  return cgen_state_->emitCall("load_float", {aggregate_state});
374  }
375  case kDOUBLE: {
376  return cgen_state_->emitCall("load_double", {aggregate_state});
377  }
378  default: {
379  return cgen_state_->ir_builder_.CreateLoad(aggregate_state);
380  }
381  }
382 }
#define CHECK_EQ(x, y)
Definition: Logger.h:205
#define LOG(tag)
Definition: Logger.h:188
const WindowFunctionContext * activateWindowFunctionContext(Executor *executor, const size_t target_index) const
llvm::Value * aggregateWindowStatePtr()
llvm::Type * get_int_type(const int width, llvm::LLVMContext &context)
static WindowFunctionContext * getActiveWindowFunctionContext(Executor *executor)
std::string get_window_agg_name(const SqlWindowFunctionKind kind, const SQLTypeInfo &window_func_ti)
void codegenWindowAvgEpilogue(llvm::Value *crt_val, llvm::Value *window_func_null_val, llvm::Value *multiplicity_lv)
static const WindowProjectNodeContext * get(Executor *executor)
const Analyzer::WindowFunction * getWindowFunction() const
llvm::Value * codegenCastBetweenIntTypes(llvm::Value *operand_lv, const SQLTypeInfo &operand_ti, const SQLTypeInfo &ti, bool upscale=true)
Definition: CastIR.cpp:256
#define AUTOMATIC_IR_METADATA(CGENSTATE)
std::vector< llvm::Value * > codegen(const Analyzer::Expr *, const bool fetch_columns, const CompilationOptions &)
Definition: IRCodegen.cpp:26
void codegenWindowFunctionStateInit(llvm::Value *aggregate_state)
llvm::Value * toBool(llvm::Value *)
Definition: LogicalIR.cpp:335
SqlWindowFunctionKind
Definition: sqldefs.h:82
const std::vector< std::shared_ptr< Analyzer::Expr > > & getArgs() const
Definition: Analyzer.h:1449
llvm::Value * codegenAggregateWindowState()
const SQLTypeInfo & get_type_info() const
Definition: Analyzer.h:78
llvm::Value * codegenWindowFunctionAggregate(const CompilationOptions &co)
#define CHECK(condition)
Definition: Logger.h:197
HOST DEVICE SQLTypes get_type() const
Definition: sqltypes.h:259
llvm::Value * codegenWindowFunctionAggregateCalls(llvm::Value *aggregate_state, const CompilationOptions &co)
SqlWindowFunctionKind getKind() const
Definition: Analyzer.h:1447
llvm::Value * codegenWindowFunction(const size_t target_index, const CompilationOptions &co)
llvm::BasicBlock * codegenWindowResetStateControlFlow()
SQLTypeInfo get_adjusted_window_type_info(const Analyzer::WindowFunction *window_func)