OmniSciDB  5ade3759e0
JoinLoopTest.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 MapD Technologies, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Build with `make -f JoinLoopTestMakefile`, compare the output
18 // with the one generated by the `generate_loop_ref.py` script.
19 
#include "JoinLoop.h"
#include "Shared/Logger.h"

#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_os_ostream.h>

#include <cinttypes>
#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
35 
// Callback invoked by the JIT-generated loop body (see the "print_iterators"
// external call emitted in create_loop_test_function) with the current
// iterator values; prints them as one comma-separated line.
//
// extern "C" keeps the symbol unmangled so the generated IR can reference it
// by the plain name "print_iterators".
//
// The PRId64 macros are used instead of "%ld" because int64_t is not `long`
// on every platform (e.g. LLP64 Windows), where "%ld" would be undefined
// behavior. On LP64 targets the printed text is unchanged.
extern "C" void print_iterators(const int64_t i, const int64_t j, const int64_t k) {
  printf("%" PRId64 ", %" PRId64 ", %" PRId64 "\n", i, j, k);
}
39 
40 namespace {
41 
42 llvm::LLVMContext g_global_context;
43 
44 void verify_function_ir(const llvm::Function* func) {
45  std::stringstream err_ss;
46  llvm::raw_os_ostream err_os(err_ss);
47  if (llvm::verifyFunction(*func, &err_os)) {
48  func->print(llvm::outs());
49  LOG(FATAL) << err_ss.str();
50  }
51 }
52 
53 llvm::Value* emit_external_call(const std::string& fname,
54  llvm::Type* ret_type,
55  const std::vector<llvm::Value*> args,
56  llvm::Module* module,
57  llvm::IRBuilder<>& builder) {
58  std::vector<llvm::Type*> arg_types;
59  for (const auto arg : args) {
60  arg_types.push_back(arg->getType());
61  }
62  auto func_ty = llvm::FunctionType::get(ret_type, arg_types, false);
63  auto func_p = module->getOrInsertFunction(fname, func_ty);
64  CHECK(func_p);
65  llvm::Value* result = builder.CreateCall(func_p, args);
66  // check the assumed type
67  CHECK_EQ(result->getType(), ret_type);
68  return result;
69 }
70 
71 llvm::Function* create_loop_test_function(llvm::LLVMContext& context,
72  llvm::Module* module,
73  const std::vector<JoinLoop>& join_loops) {
74  std::vector<llvm::Type*> argument_types;
75  const auto ft =
76  llvm::FunctionType::get(llvm::Type::getVoidTy(context), argument_types, false);
77  const auto func = llvm::Function::Create(
78  ft, llvm::Function::ExternalLinkage, "loop_test_func", module);
79  const auto entry_bb = llvm::BasicBlock::Create(context, "entry", func);
80  const auto exit_bb = llvm::BasicBlock::Create(context, "exit", func);
81  llvm::IRBuilder<> builder(context);
82  builder.SetInsertPoint(exit_bb);
83  builder.CreateRetVoid();
84  const auto loop_body_bb = JoinLoop::codegen(
85  join_loops,
86  [&builder, module](const std::vector<llvm::Value*>& iterators) {
87  const auto loop_body_bb = llvm::BasicBlock::Create(
88  builder.getContext(), "loop_body", builder.GetInsertBlock()->getParent());
89  builder.SetInsertPoint(loop_body_bb);
90  const std::vector<llvm::Value*> args(iterators.begin() + 1, iterators.end());
91  emit_external_call("print_iterators",
92  llvm::Type::getVoidTy(builder.getContext()),
93  args,
94  module,
95  builder);
96  return loop_body_bb;
97  },
98  nullptr,
99  exit_bb,
100  builder);
101  builder.SetInsertPoint(entry_bb);
102  builder.CreateBr(loop_body_bb);
103  verify_function_ir(func);
104  return func;
105 }
106 
107 std::unique_ptr<llvm::Module> create_loop_test_module() {
108  return llvm::make_unique<llvm::Module>("Nested loops JIT", g_global_context);
109 }
110 
111 std::pair<void*, std::unique_ptr<llvm::ExecutionEngine>> native_codegen(
112  std::unique_ptr<llvm::Module>& module,
113  llvm::Function* func) {
114  llvm::ExecutionEngine* execution_engine{nullptr};
115 
116  auto init_err = llvm::InitializeNativeTarget();
117  CHECK(!init_err);
118 
119  llvm::InitializeAllTargetMCs();
120  llvm::InitializeNativeTargetAsmPrinter();
121  llvm::InitializeNativeTargetAsmParser();
122 
123  std::string err_str;
124  llvm::EngineBuilder eb(std::move(module));
125  eb.setErrorStr(&err_str);
126  eb.setEngineKind(llvm::EngineKind::JIT);
127  llvm::TargetOptions to;
128  to.EnableFastISel = true;
129  eb.setTargetOptions(to);
130  execution_engine = eb.create();
131  CHECK(execution_engine);
132 
133  execution_engine->finalizeObject();
134  auto native_code = execution_engine->getPointerToFunction(func);
135 
136  CHECK(native_code);
137  return {native_code, std::unique_ptr<llvm::ExecutionEngine>(execution_engine)};
138 }
139 
// Builds the JoinLoop descriptors for one test case: one descriptor per
// entry of `upper_bounds`, in the same order.
//
// Bit i of `mask` selects the shape of descriptor i:
//   - set:   its domain carries only slot_lookup_result, the constant 99 when
//            the next unconsumed bit of `cond_mask` is set and -1 otherwise;
//   - clear: its domain's upper_bound is the constant upper_bounds[i].
// Every domain-codegen lambda CHECKs that it receives exactly i + 1 iterator
// values and that the first of them is null. All constants are created in
// g_global_context.
//
// NOTE(review): this listing skips source lines 149-150 and 168-169 (the
// numbering jumps 148 -> 151 and 167 -> 170), so both emplace_back calls
// below are missing their leading constructor arguments -- presumably the
// JoinLoopKind / JoinType values. Restore them from upstream; TODO confirm.
140 std::vector<JoinLoop> generate_descriptors(const unsigned mask,
141  const unsigned cond_mask,
142  const std::vector<int64_t>& upper_bounds) {
143  std::vector<JoinLoop> join_loops;
  // Index of the next cond_mask bit to consume; advanced only for set mask bits.
144  size_t cond_idx{0};
145  for (size_t i = 0; i < upper_bounds.size(); ++i) {
146  if (mask & (1 << i)) {
147  const bool cond_is_true = cond_mask & (1 << cond_idx);
148  join_loops.emplace_back(
151  [i, cond_is_true](const std::vector<llvm::Value*>& v) {
152  CHECK_EQ(i + 1, v.size());
153  CHECK(!v.front());
154  JoinLoopDomain domain{{0}};
  // NOTE(review): presumably a non-negative lookup result means "match" and
  // -1 means "no match" -- confirm against JoinLoop.cpp.
155  domain.slot_lookup_result = cond_is_true
156  ? ll_int(int64_t(99), g_global_context)
157  : ll_int(int64_t(-1), g_global_context);
158  return domain;
159  },
  // The remaining JoinLoop callbacks are left null (see JoinLoop.h for the
  // meaning of each slot); the last argument names the loop "i<index>".
160  nullptr,
161  nullptr,
162  nullptr,
163  "i" + std::to_string(i));
164  ++cond_idx;
165  } else {
166  const auto upper_bound = upper_bounds[i];
167  join_loops.emplace_back(
170  [i, upper_bound](const std::vector<llvm::Value*>& v) {
171  CHECK_EQ(i + 1, v.size());
172  CHECK(!v.front());
173  JoinLoopDomain domain{{0}};
174  domain.upper_bound = ll_int<int64_t>(upper_bound, g_global_context);
175  return domain;
176  },
177  nullptr,
178  nullptr,
179  nullptr,
180  "i" + std::to_string(i));
181  }
182  }
183  return join_loops;
184 }
185 
186 } // namespace
187 
188 int main() {
189  std::vector<int64_t> upper_bounds{5, 3, 9};
190  for (unsigned mask = 0; mask < static_cast<unsigned>(1 << upper_bounds.size());
191  ++mask) {
192  const unsigned mask_bitcount = __builtin_popcount(mask);
193  for (unsigned cond_mask = 0; cond_mask < static_cast<unsigned>(1 << mask_bitcount);
194  ++cond_mask) {
195  auto module = create_loop_test_module();
196  const auto join_loops = generate_descriptors(mask, cond_mask, upper_bounds);
197  const auto function =
198  create_loop_test_function(g_global_context, module.get(), join_loops);
199  const auto& func_and_ee = native_codegen(module, function);
200  reinterpret_cast<int64_t (*)()>(func_and_ee.first)();
201  }
202  }
203  return 0;
204 }
#define CHECK_EQ(x, y)
Definition: Logger.h:195
llvm::Value * emit_external_call(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value *> args, llvm::Module *module, llvm::IRBuilder<> &builder)
#define LOG(tag)
Definition: Logger.h:182
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
void print_iterators(const int64_t i, const int64_t j, const int64_t k)
llvm::Function * create_loop_test_function(llvm::LLVMContext &context, llvm::Module *module, const std::vector< JoinLoop > &join_loops)
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value *> &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, llvm::IRBuilder<> &builder)
Definition: JoinLoop.cpp:45
std::string to_string(char const *&&v)
int main()
T v(const TargetValue &r)
llvm::Value * slot_lookup_result
Definition: JoinLoop.h:45
llvm::Value * upper_bound
Definition: JoinLoop.h:43
std::unique_ptr< llvm::Module > create_loop_test_module()
void verify_function_ir(const llvm::Function *func)
std::vector< JoinLoop > generate_descriptors(const unsigned mask, const unsigned cond_mask, const std::vector< int64_t > &upper_bounds)
#define CHECK(condition)
Definition: Logger.h:187
std::pair< void *, std::unique_ptr< llvm::ExecutionEngine > > native_codegen(std::unique_ptr< llvm::Module > &module, llvm::Function *func)