OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
JoinLoopTest.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Build with `make -f JoinLoopTestMakefile`, compare the output
18 // with the one generated by the `generate_loop_ref.py` script.
19 
#include "JoinLoop.h"
#include "Logger/Logger.h"

#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_os_ostream.h>

#include <cinttypes>
#include <memory>
#include <sstream>
#include <vector>
35 
36 extern "C" RUNTIME_EXPORT void print_iterators(const int64_t i,
37  const int64_t j,
38  const int64_t k) {
39  printf("%ld, %ld, %ld\n", i, j, k);
40 }
41 
42 namespace {
43 
44 llvm::LLVMContext g_global_context;
45 
46 void verify_function_ir(const llvm::Function* func) {
47  std::stringstream err_ss;
48  llvm::raw_os_ostream err_os(err_ss);
49  if (llvm::verifyFunction(*func, &err_os)) {
50  func->print(llvm::outs());
51  LOG(FATAL) << err_ss.str();
52  }
53 }
54 
55 llvm::Value* emit_external_call(const std::string& fname,
56  llvm::Type* ret_type,
57  const std::vector<llvm::Value*> args,
58  llvm::Module* llvm_module,
59  llvm::IRBuilder<>& builder) {
60  std::vector<llvm::Type*> arg_types;
61  for (const auto arg : args) {
62  arg_types.push_back(arg->getType());
63  }
64  auto func_ty = llvm::FunctionType::get(ret_type, arg_types, false);
65  auto func_p = llvm_module->getOrInsertFunction(fname, func_ty);
66  CHECK(func_p);
67  llvm::Value* result = builder.CreateCall(func_p, args);
68  // check the assumed type
69  CHECK_EQ(result->getType(), ret_type);
70  return result;
71 }
72 
73 llvm::Function* create_loop_test_function(llvm::LLVMContext& context,
74  llvm::Module* llvm_module,
75  const std::vector<JoinLoop>& join_loops) {
76  std::vector<llvm::Type*> argument_types;
77  const auto ft =
78  llvm::FunctionType::get(llvm::Type::getVoidTy(context), argument_types, false);
79  const auto func = llvm::Function::Create(
80  ft, llvm::Function::ExternalLinkage, "loop_test_func", llvm_module);
81  const auto entry_bb = llvm::BasicBlock::Create(context, "entry", func);
82  const auto exit_bb = llvm::BasicBlock::Create(context, "exit", func);
83  llvm::IRBuilder<> builder(context);
84  builder.SetInsertPoint(exit_bb);
85  builder.CreateRetVoid();
86  const auto loop_body_bb = JoinLoop::codegen(
87  join_loops,
88  [&builder, llvm_module](const std::vector<llvm::Value*>& iterators) {
89  const auto loop_body_bb = llvm::BasicBlock::Create(
90  builder.getContext(), "loop_body", builder.GetInsertBlock()->getParent());
91  builder.SetInsertPoint(loop_body_bb);
92  const std::vector<llvm::Value*> args(iterators.begin() + 1, iterators.end());
93  emit_external_call("print_iterators",
94  llvm::Type::getVoidTy(builder.getContext()),
95  args,
96  llvm_module,
97  builder);
98  return loop_body_bb;
99  },
100  nullptr,
101  exit_bb,
102  builder);
103  builder.SetInsertPoint(entry_bb);
104  builder.CreateBr(loop_body_bb);
105  verify_function_ir(func);
106  return func;
107 }
108 
109 std::unique_ptr<llvm::Module> create_loop_test_module() {
110  return std::make_unique<llvm::Module>("Nested loops JIT", g_global_context);
111 }
112 
113 std::pair<void*, std::unique_ptr<llvm::ExecutionEngine>> native_codegen(
114  std::unique_ptr<llvm::Module>& llvm_module,
115  llvm::Function* func) {
116  llvm::ExecutionEngine* execution_engine{nullptr};
117 
118  auto init_err = llvm::InitializeNativeTarget();
119  CHECK(!init_err);
120 
121  llvm::InitializeAllTargetMCs();
122  llvm::InitializeNativeTargetAsmPrinter();
123  llvm::InitializeNativeTargetAsmParser();
124 
125  std::string err_str;
126  llvm::EngineBuilder eb(std::move(llvm_module));
127  eb.setErrorStr(&err_str);
128  eb.setEngineKind(llvm::EngineKind::JIT);
129  llvm::TargetOptions to;
130  to.EnableFastISel = true;
131  eb.setTargetOptions(to);
132  execution_engine = eb.create();
133  CHECK(execution_engine);
134 
135  execution_engine->finalizeObject();
136  auto native_code = execution_engine->getPointerToFunction(func);
137 
138  CHECK(native_code);
139  return {native_code, std::unique_ptr<llvm::ExecutionEngine>(execution_engine)};
140 }
141 
142 std::vector<JoinLoop> generate_descriptors(const unsigned mask,
143  const unsigned cond_mask,
144  const std::vector<int64_t>& upper_bounds) {
145  std::vector<JoinLoop> join_loops;
146  size_t cond_idx{0};
147  for (size_t i = 0; i < upper_bounds.size(); ++i) {
148  if (mask & (1 << i)) {
149  const bool cond_is_true = cond_mask & (1 << cond_idx);
150  join_loops.emplace_back(
153  [i, cond_is_true](const std::vector<llvm::Value*>& v) {
154  CHECK_EQ(i + 1, v.size());
155  CHECK(!v.front());
156  JoinLoopDomain domain{{0}};
157  domain.slot_lookup_result = cond_is_true
158  ? ll_int(int64_t(99), g_global_context)
159  : ll_int(int64_t(-1), g_global_context);
160  return domain;
161  },
162  nullptr,
163  nullptr,
164  nullptr,
165  false,
166  "i" + std::to_string(i));
167  ++cond_idx;
168  } else {
169  const auto upper_bound = upper_bounds[i];
170  join_loops.emplace_back(
173  [i, upper_bound](const std::vector<llvm::Value*>& v) {
174  CHECK_EQ(i + 1, v.size());
175  CHECK(!v.front());
176  JoinLoopDomain domain{{0}};
177  domain.upper_bound = ll_int<int64_t>(upper_bound, g_global_context);
178  return domain;
179  },
180  nullptr,
181  nullptr,
182  nullptr,
183  false,
184  "i" + std::to_string(i));
185  }
186  }
187  return join_loops;
188 }
189 
190 } // namespace
191 
192 int main() {
193  std::vector<int64_t> upper_bounds{5, 3, 9};
194  for (unsigned mask = 0; mask < static_cast<unsigned>(1 << upper_bounds.size());
195  ++mask) {
196  const unsigned mask_bitcount = __builtin_popcount(mask);
197  for (unsigned cond_mask = 0; cond_mask < static_cast<unsigned>(1 << mask_bitcount);
198  ++cond_mask) {
199  auto llvm_module = create_loop_test_module();
200  const auto join_loops = generate_descriptors(mask, cond_mask, upper_bounds);
201  const auto function =
202  create_loop_test_function(g_global_context, llvm_module.get(), join_loops);
203  const auto& func_and_ee = native_codegen(llvm_module, function);
204  reinterpret_cast<int64_t (*)()>(func_and_ee.first)();
205  }
206  }
207  return 0;
208 }
DEVICE auto upper_bound(ARGS &&...args)
Definition: gpu_enabled.h:123
#define CHECK_EQ(x, y)
Definition: Logger.h:301
#define LOG(tag)
Definition: Logger.h:285
llvm::ConstantInt * ll_int(const T v, llvm::LLVMContext &context)
std::string to_string(char const *&&v)
void verify_function_ir(const llvm::Function *func)
llvm::Value * emit_external_call(const std::string &fname, llvm::Type *ret_type, const std::vector< llvm::Value * > args, llvm::Module *llvm_module, llvm::IRBuilder<> &builder)
static llvm::BasicBlock * codegen(const std::vector< JoinLoop > &join_loops, const std::function< llvm::BasicBlock *(const std::vector< llvm::Value * > &)> &body_codegen, llvm::Value *outer_iter, llvm::BasicBlock *exit_bb, CgenState *cgen_state)
Definition: JoinLoop.cpp:50
RUNTIME_EXPORT void print_iterators(const int64_t i, const int64_t j, const int64_t k)
std::pair< void *, std::unique_ptr< llvm::ExecutionEngine > > native_codegen(std::unique_ptr< llvm::Module > &llvm_module, llvm::Function *func)
llvm::Value * slot_lookup_result
Definition: JoinLoop.h:47
#define RUNTIME_EXPORT
llvm::ManagedStatic< llvm::LLVMContext > g_global_context
llvm::Value * upper_bound
Definition: JoinLoop.h:45
std::unique_ptr< llvm::Module > create_loop_test_module()
std::vector< JoinLoop > generate_descriptors(const unsigned mask, const unsigned cond_mask, const std::vector< int64_t > &upper_bounds)
#define CHECK(condition)
Definition: Logger.h:291
llvm::Function * create_loop_test_function(llvm::LLVMContext &context, llvm::Module *llvm_module, const std::vector< JoinLoop > &join_loops)