OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SqliteConnector.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
23 #include "SqliteConnector.h"
24 
25 #include <iostream>
26 
27 #include "Logger/Logger.h"
28 
29 using std::cout;
30 using std::endl;
31 using std::runtime_error;
32 using std::string;
33 
34 SqliteConnector::SqliteConnector(const string& dbName, const string& dir)
35  : dbName_(dbName) {
36  string connectString(dir);
37  if (connectString.size() > 0 && connectString[connectString.size() - 1] != '/') {
38  connectString.push_back('/');
39  }
40  connectString += dbName;
41  int returnCode = sqlite3_open(connectString.c_str(), &db_);
42  if (returnCode != SQLITE_OK) {
43  throwError();
44  }
45 }
46 
47 SqliteConnector::SqliteConnector(sqlite3* db) : db_(db) {}
48 
// NOTE(review): the line carrying this definition's signature is absent from
// this listing (embedded line 49 is skipped). Presumably this is the body of
// SqliteConnector::~SqliteConnector() -- confirm against the original file.
//
// Closes db_ via sqlite3_close() only when dbName_ is non-empty; otherwise db_
// is left untouched. The (dbName, dir) constructor is the only one shown that
// sets dbName_, so this presumably restricts closing to handles this object
// opened itself -- confirm.
50  if (!dbName_.empty()) {
51  sqlite3_close(db_);
52  }
53 }
54 
// NOTE(review): the line carrying this definition's signature is absent from
// this listing (embedded line 55 is skipped). The other functions in this file
// invoke it as throwError() and it reads db_, so it is presumably a
// SqliteConnector member -- confirm against the original file.
//
// Copies the message reported by sqlite3_errmsg(db_) and throws it as a
// std::runtime_error prefixed with "Sqlite3 Error: ". Control never reaches the
// closing brace normally.
56  string errorMsg(sqlite3_errmsg(db_));
57  throw runtime_error("Sqlite3 Error: " + errorMsg);
58 }
59 
60 std::string get_column_datum(int column_type, sqlite3_stmt* stmt, size_t column_index) {
61  const char* datum_ptr;
62  if (column_type == SQLITE_BLOB) {
63  datum_ptr = static_cast<const char*>(sqlite3_column_blob(stmt, column_index));
64  } else {
65  datum_ptr = reinterpret_cast<const char*>(sqlite3_column_text(stmt, column_index));
66  }
67  size_t datum_size = sqlite3_column_bytes(stmt, column_index);
68  return {datum_ptr, datum_size};
69 }
70 
71 void SqliteConnector::query_with_text_params(const std::string& queryString,
72  const std::vector<std::string>& text_params,
73  const std::vector<BindType>& bind_types) {
74  if (!bind_types.empty()) {
75  CHECK_EQ(text_params.size(), bind_types.size());
76  }
77 
78  atFirstResult_ = true;
79  numRows_ = 0;
80  numCols_ = 0;
81  columnNames.clear();
82  columnTypes.clear();
83  results_.clear();
84  sqlite3_stmt* stmt;
85  int returnCode = sqlite3_prepare_v2(db_, queryString.c_str(), -1, &stmt, nullptr);
86  if (returnCode != SQLITE_OK) {
87  throwError();
88  }
89 
90  int num_params = 1;
91  for (auto text_param : text_params) {
92  if (!bind_types.empty() && bind_types[num_params - 1] == BindType::BLOB) {
93  returnCode = sqlite3_bind_blob(
94  stmt, num_params++, text_param.c_str(), text_param.size(), SQLITE_TRANSIENT);
95  } else if (!bind_types.empty() && bind_types[num_params - 1] == BindType::NULL_TYPE) {
96  returnCode = sqlite3_bind_null(stmt, num_params++);
97  } else {
98  returnCode = sqlite3_bind_text(
99  stmt, num_params++, text_param.c_str(), text_param.size(), SQLITE_TRANSIENT);
100  }
101  if (returnCode != SQLITE_OK) {
102  throwError();
103  }
104  }
105 
106  do {
107  returnCode = sqlite3_step(stmt);
108  if (returnCode != SQLITE_ROW && returnCode != SQLITE_DONE) {
109  throwError();
110  }
111  if (returnCode == SQLITE_DONE) {
112  break;
113  }
114  if (atFirstResult_) {
115  numCols_ = sqlite3_column_count(stmt);
116  for (size_t c = 0; c < numCols_; ++c) {
117  columnNames.emplace_back(sqlite3_column_name(stmt, c));
118  columnTypes.push_back(sqlite3_column_type(stmt, c));
119  }
120  results_.resize(numCols_);
121  atFirstResult_ = false;
122  }
123  numRows_++;
124  for (size_t c = 0; c < numCols_; ++c) {
125  auto column_type = sqlite3_column_type(stmt, c);
126  bool is_null = (column_type == SQLITE_NULL);
127  auto col_text = get_column_datum(column_type, stmt, c);
128  if (is_null) {
129  CHECK(col_text.empty());
130  }
131  results_[c].emplace_back(NullableResult{col_text, is_null});
132  }
133  } while (1 == 1); // Loop control in break statement above
134 
135  sqlite3_finalize(stmt);
136 }
137 
// NOTE(review): the first line of this definition is absent from this listing
// (embedded line 138 is skipped). Judging by the parameter list and the call
// below, this is presumably the two-argument overload of
// SqliteConnector::query_with_text_params -- confirm against the original file.
//
// Forwards to the three-argument form with an empty bind-type list; with an
// empty list that form binds every parameter through sqlite3_bind_text.
139  const std::string& queryString,
140  const std::vector<std::string>& text_params) {
141  query_with_text_params(queryString, text_params, {});
142 }
143 
144 void SqliteConnector::query_with_text_param(const std::string& queryString,
145  const std::string& text_param) {
146  query_with_text_params(queryString, std::vector<std::string>{text_param});
147 }
148 
149 void SqliteConnector::query(const std::string& queryString) {
150  query_with_text_params(queryString, std::vector<std::string>{});
151 }
152 
153 void SqliteConnector::batch_insert(const std::string& table_name,
154  std::vector<std::vector<std::string>>& insert_vals) {
155  const size_t num_rows = insert_vals.size();
156  if (!num_rows) {
157  return;
158  }
159  const size_t num_cols(insert_vals[0].size());
160  if (!num_cols) {
161  return;
162  }
163  std::string paramertized_query = "INSERT INTO " + table_name + " VALUES(";
164  for (size_t col_idx = 0; col_idx < num_cols - 1; ++col_idx) {
165  paramertized_query += "?, ";
166  }
167  paramertized_query += "?)";
168 
169  query("BEGIN TRANSACTION");
170 
171  sqlite3_stmt* stmt;
172  int returnCode =
173  sqlite3_prepare_v2(db_, paramertized_query.c_str(), -1, &stmt, nullptr);
174  if (returnCode != SQLITE_OK) {
175  throwError();
176  }
177 
178  for (size_t r = 0; r < num_rows; ++r) {
179  const auto& row_insert_vals = insert_vals[r];
180  int num_params = 1;
181  for (const auto& insert_field : row_insert_vals) {
182  returnCode = sqlite3_bind_text(stmt,
183  num_params++,
184  insert_field.c_str(),
185  insert_field.size(),
186  SQLITE_TRANSIENT);
187  if (returnCode != SQLITE_OK) {
188  throwError();
189  }
190  }
191  returnCode = sqlite3_step(stmt);
192  if (returnCode != SQLITE_DONE) {
193  throwError();
194  }
195  sqlite3_reset(stmt);
196  }
197  sqlite3_finalize(stmt);
198  query("END TRANSACTION");
199 }
#define CHECK_EQ(x, y)
Definition: Logger.h:301
virtual void query_with_text_params(std::string const &query_only)
virtual void batch_insert(const std::string &table_name, std::vector< std::vector< std::string >> &insert_vals)
virtual void query(const std::string &queryString)
std::string get_column_datum(int column_type, sqlite3_stmt *stmt, size_t column_index)
CONSTEXPR DEVICE bool is_null(const T &value)
std::vector< std::vector< NullableResult > > results_
std::vector< int > columnTypes
std::string dbName_
virtual ~SqliteConnector()
#define CHECK(condition)
Definition: Logger.h:291
std::vector< std::string > columnNames
virtual void query_with_text_param(const std::string &queryString, const std::string &text_param)