OmniSciDB  ca0c39ec8f
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StringDictionaryProxy.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "Logger/Logger.h"
20 #include "Shared/ThreadInfo.h"
21 #include "Shared/misc.h"
22 #include "Shared/sqltypes.h"
23 #include "Shared/thread_count.h"
25 #include "StringOps/StringOps.h"
26 #include "Utils/Regexp.h"
27 #include "Utils/StringLike.h"
28 
29 #include <tbb/parallel_for.h>
30 #include <tbb/task_arena.h>
31 
32 #include <algorithm>
33 #include <iomanip>
34 #include <iostream>
35 #include <string>
36 #include <string_view>
37 #include <thread>
38 
39 StringDictionaryProxy::StringDictionaryProxy(std::shared_ptr<StringDictionary> sd,
40  const int32_t string_dict_id,
41  const int64_t generation)
42  : string_dict_(sd), string_dict_id_(string_dict_id), generation_(generation) {}
43 
44 int32_t truncate_to_generation(const int32_t id, const size_t generation) {
46  return id;
47  }
48  CHECK_GE(id, 0);
49  return static_cast<size_t>(id) >= generation ? StringDictionary::INVALID_STR_ID : id;
50 }
51 
// NOTE(review): this Doxygen listing dropped two lines of this definition:
//   - original line 52, the signature. The member summary on this page gives it as
//     `std::vector<int32_t> getTransientBulk(const std::vector<std::string>& strings) const`
//     ("Executes read-only lookup of a vector of strings and returns a vector of their
//     integer ids").
//   - original line 54, whose content cannot be recovered from this page.
// Restore both from the real source before compiling.
//
// Read-only bulk lookup: allocates one output id per input string and lets
// getTransientBulkImpl fill them. The final argument `true` asks it to take this proxy's
// read lock for the transient lookup. The "not found" count it returns is discarded here.
// Unresolved strings simply keep StringDictionary::INVALID_STR_ID in the result.
53  const std::vector<std::string>& strings) const {
55  std::vector<int32_t> string_ids(strings.size());
56  getTransientBulkImpl(strings, string_ids.data(), true);
57  return string_ids;
58 }
59 
// NOTE(review): this Doxygen listing dropped the signature (original line 60) and original
// line 62 of this definition. The function name is not recoverable from this page. What
// survives shows a non-const member taking `strings` by const reference and returning a
// std::vector<int32_t> of ids.
//
// Bulk get-or-add, in three steps:
//   1. Resolve every string against the underlying dictionary at generation_, without
//      holding the proxy lock.
//   2. Only if some were not found, take the write lock.
//   3. Under that lock, add each unresolved string as a transient via
//      getOrAddTransientUnlocked.
// Returns one id per input string, in input order.
61  const std::vector<std::string>& strings) {
63  const size_t num_strings = strings.size();
64  std::vector<int32_t> string_ids(num_strings);
65  if (num_strings == 0) {
66  return string_ids;
67  }
68  // Since new strings added to a StringDictionaryProxy are not materialized in the
69  // proxy's underlying StringDictionary, we can use the fast parallel
70  // StringDictionary::getBulk method to fetch ids from the underlying dictionary (which
71  // will return StringDictionary::INVALID_STR_ID for strings that don't exist)
72 
73  // Don't need to be under lock here as the string ids for strings in the underlying
74  // materialized dictionary are immutable
75  const size_t num_strings_not_found =
76  string_dict_->getBulk(strings, string_ids.data(), generation_);
77  if (num_strings_not_found > 0) {
// Exclusive lock: getOrAddTransientUnlocked mutates the transient containers.
78  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
79  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
80  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
81  string_ids[string_idx] = getOrAddTransientUnlocked(strings[string_idx]);
82  }
83  }
84  }
85  return string_ids;
86 }
87 
88 template <typename String>
90  unsigned const new_index = transient_str_to_int_.size();
91  auto transient_id = transientIndexToId(new_index);
92  auto const emplaced = transient_str_to_int_.emplace(str, transient_id);
93  if (emplaced.second) { // (str, transient_id) was added to transient_str_to_int_.
94  transient_string_vec_.push_back(&emplaced.first->first);
95  } else { // str already exists in transient_str_to_int_. Return existing transient_id.
96  transient_id = emplaced.first->second;
97  }
98  return transient_id;
99 }
100 
101 int32_t StringDictionaryProxy::getOrAddTransient(const std::string& str) {
102  auto const string_id = getIdOfStringFromClient(str);
103  if (string_id != StringDictionary::INVALID_STR_ID) {
104  return string_id;
105  }
106  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
107  return getOrAddTransientUnlocked(str);
108 }
109 
110 int32_t StringDictionaryProxy::getIdOfString(const std::string& str) const {
111  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
112  auto const str_id = getIdOfStringFromClient(str);
113  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
114  return str_id;
115  }
116  auto it = transient_str_to_int_.find(str);
117  return it != transient_str_to_int_.end() ? it->second
119 }
120 
121 template <typename String>
122 int32_t StringDictionaryProxy::getIdOfStringFromClient(const String& str) const {
123  CHECK_GE(generation_, 0);
124  return truncate_to_generation(string_dict_->getIdOfString(str), generation_);
125 }
126 
127 int32_t StringDictionaryProxy::getIdOfStringNoGeneration(const std::string& str) const {
128  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
129  auto str_id = string_dict_->getIdOfString(str);
130  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
131  return str_id;
132  }
133  auto it = transient_str_to_int_.find(str);
134  return it != transient_str_to_int_.end() ? it->second
136 }
137 
// NOTE(review): the opening line of this definition (original line 138: the `extern "C"`
// linkage, return type and function name) is missing from this Doxygen listing. Restore it
// from the real source.
//
// Runtime helper: reinterprets `proxy_ptr` as a StringDictionaryProxy (CHECKs it is
// non-null) and returns the character pointer half of getStringBytes(string_id). The
// length half is unpacked but unused; StringDictionaryProxy_getStringLength returns it.
139  int8_t* proxy_ptr,
140  int32_t string_id) {
141  CHECK(proxy_ptr != nullptr);
142  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
// For transient ids the pointer comes from std::string::c_str() and is NUL-terminated.
// NOTE(review): for persisted ids, termination depends on StringDictionary::getStringBytes —
// confirm before treating the result as a C string.
143  auto [c_str, len] = proxy->getStringBytes(string_id);
144  return c_str;
145 }
146 
147 extern "C" DEVICE RUNTIME_EXPORT size_t
148 StringDictionaryProxy_getStringLength(int8_t* proxy_ptr, int32_t string_id) {
149  CHECK(proxy_ptr != nullptr);
150  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
151  auto [c_str, len] = proxy->getStringBytes(string_id);
152  return len;
153 }
154 
155 extern "C" DEVICE RUNTIME_EXPORT int32_t
156 StringDictionaryProxy_getStringId(int8_t* proxy_ptr, char* c_str_ptr) {
157  CHECK(proxy_ptr != nullptr);
158  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
159  std::string str(c_str_ptr);
160  return proxy->getOrAddTransient(str);
161 }
162 
163 std::string StringDictionaryProxy::getString(int32_t string_id) const {
164  if (inline_int_null_value<int32_t>() == string_id) {
165  return "";
166  }
167  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
168  return getStringUnlocked(string_id);
169 }
170 
171 std::string StringDictionaryProxy::getStringUnlocked(const int32_t string_id) const {
172  if (string_id >= 0 && storageEntryCount() > 0) {
173  return string_dict_->getString(string_id);
174  }
175  unsigned const string_index = transientIdToIndex(string_id);
176  CHECK_LT(string_index, transient_string_vec_.size());
177  return *transient_string_vec_[string_index];
178 }
179 
180 std::vector<std::string> StringDictionaryProxy::getStrings(
181  const std::vector<int32_t>& string_ids) const {
182  std::vector<std::string> strings;
183  if (!string_ids.empty()) {
184  strings.reserve(string_ids.size());
185  for (const auto string_id : string_ids) {
186  if (string_id >= 0) {
187  strings.emplace_back(string_dict_->getString(string_id));
188  } else if (inline_int_null_value<int32_t>() == string_id) {
189  strings.emplace_back("");
190  } else {
191  unsigned const string_index = transientIdToIndex(string_id);
192  strings.emplace_back(*transient_string_vec_[string_index]);
193  }
194  }
195  }
196  return strings;
197 }
198 
199 template <typename String>
201  const String& lookup_string) const {
202  const auto it = transient_str_to_int_.find(lookup_string);
204  : it->second;
205 }
206 
// NOTE(review): this Doxygen listing dropped three pieces of this definition:
//   - Original lines 207-208, the signature. The member summary gives it as
//     `TranslationMap<Datum> buildNumericTranslationMap(
//         const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const`.
//   - Original line 212, which declares and initializes `translation_map`. Not
//     recoverable here.
//   - Original line 223, presumably `tbb::parallel_for(` given the blocked_range and
//     lambda arguments that follow.
// Restore these from the real source before compiling.
//
// Builds a map from every string id of this proxy to the Datum produced by applying
// `string_op_infos` (at least one is required, CHECK) and numericEval to that string.
// Steps:
//   1. Transient ids (domainStart() up to -2) are evaluated here, in parallel when there
//      are more than 10000 of them.
//   2. Persisted ids are delegated to the underlying dictionary's
//      buildDictionaryNumericTranslationMap, only when generation_ > 0.
//   3. The untranslated count is set to 0.
// NOTE(review): no rw_mutex_ lock is taken here although getStringUnlocked reads
// transient_string_vec_ — confirm callers hold the lock.
209  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
210  auto timer = DEBUG_TIMER(__func__);
211  CHECK(string_op_infos.size());
213  if (translation_map.empty()) {
214  return translation_map;
215  }
216 
217  const StringOps_Namespace::StringOps string_ops(string_op_infos);
218 
219  const size_t num_transient_entries = translation_map.numTransients();
220  if (num_transient_entries) {
221  const int32_t map_domain_start = translation_map.domainStart();
// Large transient sets are evaluated in parallel; -1 is excluded because it is not a
// transient id (transient ids start at -2).
222  if (num_transient_entries > 10000UL) {
224  tbb::blocked_range<int32_t>(map_domain_start, -1),
225  [&](const tbb::blocked_range<int32_t>& r) {
226  const int32_t start_idx = r.begin();
227  const int32_t end_idx = r.end();
228  for (int32_t source_string_id = start_idx; source_string_id < end_idx;
229  ++source_string_id) {
230  const auto source_string = getStringUnlocked(source_string_id);
231  translation_map[source_string_id] = string_ops.numericEval(source_string);
232  }
233  });
234  } else {
235  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
236  ++source_string_id) {
237  const auto source_string = getStringUnlocked(source_string_id);
238  translation_map[source_string_id] = string_ops.numericEval(source_string);
239  }
240  }
241  }
242 
// Persisted strings are translated by the underlying dictionary, up to generation_.
243  Datum* translation_map_stored_entries_ptr = translation_map.storageData();
244  if (generation_ > 0) {
245  string_dict_->buildDictionaryNumericTranslationMap(
246  translation_map_stored_entries_ptr, generation_, string_op_infos);
247  }
248  translation_map.setNumUntranslatedStrings(0UL);
249 
250  // Todo(todd): Set range start/end with scan
251 
252  return translation_map;
253 }
254 
257  const StringDictionaryProxy* dest_proxy,
258  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
259  auto timer = DEBUG_TIMER(__func__);
260  IdMap id_map = initIdMap();
261 
262  if (id_map.empty()) {
263  return id_map;
264  }
265 
266  const StringOps_Namespace::StringOps string_ops(string_op_infos);
267 
268  // First map transient strings, store at front of vector map
269  const size_t num_transient_entries = id_map.numTransients();
270  size_t num_transient_strings_not_translated = 0UL;
271  if (num_transient_entries) {
272  std::vector<std::string> transient_lookup_strings(num_transient_entries);
273  if (string_ops.size()) {
275  transient_string_vec_.cend(),
276  transient_lookup_strings.rbegin(),
277  [&](std::string const* ptr) { return string_ops(*ptr); });
278  } else {
280  transient_string_vec_.cend(),
281  transient_lookup_strings.rbegin(),
282  [](std::string const* ptr) { return *ptr; });
283  }
284 
285  // This lookup may have a different snapshot of
286  // dest_proxy transients and dictionary than what happens under
287  // the below dest_proxy_read_lock. We may need an unlocked version of
288  // getTransientBulk to ensure consistency (I don't believe
289  // current behavior would cause crashes/races, verify this though)
290 
291  // Todo(mattp): Consider variant of getTransientBulkImp that can take
292  // a vector of pointer-to-strings so we don't have to materialize
293  // transient_string_vec_ into transient_lookup_strings.
294 
295  num_transient_strings_not_translated =
296  dest_proxy->getTransientBulkImpl(transient_lookup_strings, id_map.data(), false);
297  }
298 
299  // Now map strings in dictionary
300  // We place non-transient strings after the transient strings
301  // if they exist, otherwise at index 0
302  int32_t* translation_map_stored_entries_ptr = id_map.storageData();
303 
304  auto dest_transient_lookup_callback = [dest_proxy, translation_map_stored_entries_ptr](
305  const std::string_view& source_string,
306  const int32_t source_string_id) {
307  translation_map_stored_entries_ptr[source_string_id] =
308  dest_proxy->lookupTransientStringUnlocked(source_string);
309  return translation_map_stored_entries_ptr[source_string_id] ==
311  };
312 
313  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
314  const size_t num_persisted_strings_not_translated =
315  generation_ > 0 ? string_dict_->buildDictionaryTranslationMap(
316  dest_proxy->string_dict_.get(),
317  translation_map_stored_entries_ptr,
318  generation_,
319  dest_proxy->generation_,
320  num_dest_transients > 0UL,
321  dest_transient_lookup_callback,
322  string_op_infos)
323  : 0UL;
324 
325  const size_t num_dest_entries = dest_proxy->entryCountUnlocked();
326  const size_t num_total_entries =
327  id_map.getVectorMap().size() - 1UL /* account for skipped entry -1 */;
328  CHECK_GT(num_total_entries, 0UL);
329  const size_t num_strings_not_translated =
330  num_transient_strings_not_translated + num_persisted_strings_not_translated;
331  CHECK_LE(num_strings_not_translated, num_total_entries);
332  id_map.setNumUntranslatedStrings(num_strings_not_translated);
333 
334  // Below is a conservative setting of range based on the size of the destination proxy,
335  // but probably not worth a scan over the data (or inline computation as we translate)
336  // to compute the actual ranges
337 
338  id_map.setRangeStart(
339  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
340  id_map.setRangeEnd(dest_proxy->storageEntryCount());
341 
342  const size_t num_entries_translated = num_total_entries - num_strings_not_translated;
343  const float match_pct =
344  100.0 * static_cast<float>(num_entries_translated) / num_total_entries;
345  VLOG(1) << std::fixed << std::setprecision(2) << match_pct << "% ("
346  << num_entries_translated << " entries) from dictionary ("
347  << string_dict_->getDbId() << ", " << string_dict_->getDictId() << ") with "
348  << num_total_entries << " total entries ( " << num_transient_entries
349  << " literals)"
350  << " translated to dictionary (" << dest_proxy->string_dict_->getDbId() << ", "
351  << dest_proxy->string_dict_->getDictId() << ") with " << num_dest_entries
352  << " total entries (" << dest_proxy->transientEntryCountUnlocked()
353  << " literals).";
354 
355  return id_map;
356 }
357 
// Locks a source proxy (shared) and a destination proxy (exclusive) in a fixed order.
//
// The lock belonging to the smaller dictionary id is always taken first, so every
// translation between the same pair of dictionaries acquires the two mutexes in the same
// order. When the ids are equal, only the destination's exclusive lock is taken. This
// assumes equal ids mean the same proxy, and locking one shared_mutex both shared and
// exclusively would self-deadlock.
void order_translation_locks(const int32_t source_dict_id,
                             const int32_t dest_dict_id,
                             std::shared_lock<std::shared_mutex>& source_proxy_read_lock,
                             std::unique_lock<std::shared_mutex>& dest_proxy_write_lock) {
  if (source_dict_id == dest_dict_id) {
    // proxies are same, only take one write lock
    dest_proxy_write_lock.lock();
    return;
  }
  const bool source_first = source_dict_id < dest_dict_id;
  if (source_first) {
    source_proxy_read_lock.lock();
  }
  dest_proxy_write_lock.lock();
  if (!source_first) {
    source_proxy_read_lock.lock();
  }
}
373 
// NOTE(review): this Doxygen listing dropped the signature (original lines 374-375) and
// original line 384. Line 384 is presumably `order_translation_locks(`, given the argument
// list on the following line. Restore both from the real source.
//
// Locking wrapper around buildIntersectionTranslationMapToOtherProxyUnlocked. It takes
// this proxy's read lock and dest_proxy's exclusive lock, in the order chosen by
// order_translation_locks, then builds the translation map. Both locks are released when
// the function returns.
376  const StringDictionaryProxy* dest_proxy,
377  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
378  const auto source_dict_id = getDictId();
379  const auto dest_dict_id = dest_proxy->getDictId();
380 
// Deferred locks: order_translation_locks decides which one to acquire first.
381  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
382  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
383  std::defer_lock);
385  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
386  return buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
387 }
388 
// NOTE(review): this Doxygen listing dropped the signature (original line 389) and
// original line 399. Line 399 is presumably `order_translation_locks(`, given the argument
// list on the following line. Restore both from the real source.
//
// Union translation, in four steps:
//   1. Under this proxy's read lock and dest_proxy's exclusive lock, build the
//      intersection map.
//   2. Add every string that could not be translated (transients first, then persisted
//      strings below domainEnd()) to dest_proxy as a transient, after applying
//      `string_op_infos` if any.
//   3. Record each string's new id in the map.
//   4. Widen the map's range start to cover dest_proxy's transients.
// Throws std::runtime_error if dest_proxy would exceed INT32_MAX - 2 transients.
390  StringDictionaryProxy* dest_proxy,
391  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
392  auto timer = DEBUG_TIMER(__func__);
393 
394  const auto source_dict_id = getDictId();
395  const auto dest_dict_id = dest_proxy->getDictId();
396  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
397  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
398  std::defer_lock);
400  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
401 
402  auto id_map =
403  buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
404  if (id_map.empty()) {
405  return id_map;
406  }
// Only when the intersection reported untranslated strings is anything added to dest.
407  const auto num_untranslated_strings = id_map.numUntranslatedStrings();
408  if (num_untranslated_strings > 0) {
409  const size_t total_post_translation_dest_transients =
410  num_untranslated_strings + dest_proxy->transientEntryCountUnlocked();
411  constexpr size_t max_allowed_transients =
412  static_cast<size_t>(std::numeric_limits<int32_t>::max() -
413  2); /* -2 accounts for INVALID_STR_ID and NULL value */
414  if (total_post_translation_dest_transients > max_allowed_transients) {
// NOTE(review): the message lacks a space after "dictionary", and it reports getDictId()
// (the source dictionary) although the limit concerns dest_proxy's transients — confirm
// which id is intended.
415  throw std::runtime_error("Union translation to dictionary" +
416  std::to_string(getDictId()) + " would result in " +
417  std::to_string(total_post_translation_dest_transients) +
418  " transient entries, which is more than limit of " +
419  std::to_string(max_allowed_transients) + " transients.");
420  }
421  const int32_t map_domain_start = id_map.domainStart();
422  const int32_t map_domain_end = id_map.domainEnd();
423 
424  const StringOps_Namespace::StringOps string_ops(string_op_infos);
425  const bool has_string_ops = string_ops.size();
426 
427  // First iterate over transient strings and add to dest map
428  // Todo (todd): Add call to fetch string_views (local) or strings (distributed)
429  // for all non-translated ids to avoid string-by-string fetch
430 
431  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
432  ++source_string_id) {
433  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
434  const auto source_string = getStringUnlocked(source_string_id);
// Safe without extra locking: dest_proxy's exclusive lock is held for this whole scope.
435  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
436  has_string_ops ? string_ops(source_string) : source_string);
437  id_map[source_string_id] = dest_string_id;
438  }
439  }
440  // Now iterate over stored strings
441  for (int32_t source_string_id = 0; source_string_id < map_domain_end;
442  ++source_string_id) {
443  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
444  const auto source_string = string_dict_->getString(source_string_id);
445  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
446  has_string_ops ? string_ops(source_string) : source_string);
447  id_map[source_string_id] = dest_string_id;
448  }
449  }
450  }
451  // We may have added transients to the destination proxy, use this to update
452  // our id map range (used downstream for ExpressionRange)
453 
454  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
455  id_map.setRangeStart(
456  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
457  return id_map;
458 }
459 
460 namespace {
461 
462 bool is_like(const std::string& str,
463  const std::string& pattern,
464  const bool icase,
465  const bool is_simple,
466  const char escape) {
467  return icase
468  ? (is_simple ? string_ilike_simple(
469  str.c_str(), str.size(), pattern.c_str(), pattern.size())
470  : string_ilike(str.c_str(),
471  str.size(),
472  pattern.c_str(),
473  pattern.size(),
474  escape))
475  : (is_simple ? string_like_simple(
476  str.c_str(), str.size(), pattern.c_str(), pattern.size())
477  : string_like(str.c_str(),
478  str.size(),
479  pattern.c_str(),
480  pattern.size(),
481  escape));
482 }
483 
484 } // namespace
485 
486 std::vector<int32_t> StringDictionaryProxy::getLike(const std::string& pattern,
487  const bool icase,
488  const bool is_simple,
489  const char escape) const {
490  CHECK_GE(generation_, 0);
491  auto result = string_dict_->getLike(pattern, icase, is_simple, escape, generation_);
492  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
493  if (is_like(*transient_string_vec_[index], pattern, icase, is_simple, escape)) {
494  result.push_back(transientIndexToId(index));
495  }
496  }
497  return result;
498 }
499 
namespace {

// Compares `str` with `pattern` lexicographically (std::string::compare) under the SQL
// operator spelled by `comp_operator`: "<", "<=", "=", ">", ">=" or "<>".
// Throws std::runtime_error for any other operator.
bool do_compare(const std::string& str,
                const std::string& pattern,
                const std::string& comp_operator) {
  const int cmp = str.compare(pattern);
  if (comp_operator == "=") {
    return cmp == 0;
  }
  if (comp_operator == "<>") {
    return cmp != 0;
  }
  if (comp_operator == "<") {
    return cmp < 0;
  }
  if (comp_operator == "<=") {
    return cmp <= 0;
  }
  if (comp_operator == ">") {
    return cmp > 0;
  }
  if (comp_operator == ">=") {
    return cmp >= 0;
  }
  throw std::runtime_error("unsupported string compare operator");
}

}  // namespace
523 
525  const std::string& pattern,
526  const std::string& comp_operator) const {
527  CHECK_GE(generation_, 0);
528  auto result = string_dict_->getCompare(pattern, comp_operator, generation_);
529  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
530  if (do_compare(*transient_string_vec_[index], pattern, comp_operator)) {
531  result.push_back(transientIndexToId(index));
532  }
533  }
534  return result;
535 }
536 
537 namespace {
538 
539 bool is_regexp_like(const std::string& str,
540  const std::string& pattern,
541  const char escape) {
542  return regexp_like(str.c_str(), str.size(), pattern.c_str(), pattern.size(), escape);
543 }
544 
545 } // namespace
546 
547 std::vector<int32_t> StringDictionaryProxy::getRegexpLike(const std::string& pattern,
548  const char escape) const {
549  CHECK_GE(generation_, 0);
550  auto result = string_dict_->getRegexpLike(pattern, escape, generation_);
551  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
552  if (is_regexp_like(*transient_string_vec_[index], pattern, escape)) {
553  result.push_back(transientIndexToId(index));
554  }
555  }
556  return result;
557 }
558 
559 int32_t StringDictionaryProxy::getOrAdd(const std::string& str) noexcept {
560  return string_dict_->getOrAdd(str);
561 }
562 
563 std::pair<const char*, size_t> StringDictionaryProxy::getStringBytes(
564  int32_t string_id) const noexcept {
565  if (string_id >= 0) {
566  return string_dict_.get()->getStringBytes(string_id);
567  }
568  unsigned const string_index = transientIdToIndex(string_id);
569  CHECK_LT(string_index, transient_string_vec_.size());
570  std::string const* const str_ptr = transient_string_vec_[string_index];
571  return {str_ptr->c_str(), str_ptr->size()};
572 }
573 
575  const size_t num_storage_entries{generation_ == -1 ? string_dict_->storageEntryCount()
576  : generation_};
577  CHECK_LE(num_storage_entries, static_cast<size_t>(std::numeric_limits<int32_t>::max()));
578  return num_storage_entries;
579 }
580 
582  // CHECK_LE(num_storage_entries,
583  // static_cast<size_t>(std::numeric_limits<int32_t>::max()));
584  const size_t num_transient_entries{transient_str_to_int_.size()};
585  CHECK_LE(num_transient_entries,
586  static_cast<size_t>(std::numeric_limits<int32_t>::max()) - 1);
587  return num_transient_entries;
588 }
589 
// NOTE(review): this Doxygen listing dropped the signature (original line 590) and the
// return statement (original line 592) of this definition. Only the shared read-lock
// acquisition on rw_mutex_ and the closing brace survive. Presumably this is the locking
// wrapper around transientEntryCountUnlocked() — confirm against the real source.
591  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
593 }
594 
// NOTE(review): only the closing brace of this definition (original line 597) survives in
// this Doxygen listing; original lines 595-596 are missing. entryCount() returns
// entryCountUnlocked() under the read lock, and the member summary describes entryCount as
// returning the total of stored and transient entries. So this is presumably
// `size_t StringDictionaryProxy::entryCountUnlocked() const` returning
// storageEntryCount() + transientEntryCountUnlocked() — confirm against the real source.
597 }
598 
600  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
601  return entryCountUnlocked();
602 }
603 
604 // Iterate over transient strings, then non-transients.
606  StringDictionary::StringCallback& serial_callback) const {
607  constexpr int32_t max_transient_id = -2;
608  // Iterate over transient strings.
609  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
610  std::string const& str = *transient_string_vec_[index];
611  int32_t const string_id = max_transient_id - index;
612  serial_callback(str, string_id);
613  }
614  // Iterate over non-transient strings.
615  string_dict_->eachStringSerially(generation_, serial_callback);
616 }
617 
618 // For each (string/_view,old_id) pair passed in:
619 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
620 // * The StringDictionary is local, so call the faster getUnlocked() method.
621 // * Store the old_id -> new_id translation into the id_map_.
625 
626  public:
628  : sdp_(sdp), id_map_(id_map) {}
629  void operator()(std::string const& str, int32_t const string_id) override {
630  operator()(std::string_view(str), string_id);
631  }
632  void operator()(std::string_view const sv, int32_t const old_id) override {
633  int32_t const new_id = sdp_->string_dict_->getUnlocked(sv);
634  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
636  : new_id;
637  }
638 };
639 
640 // For each (string,old_id) pair passed in:
641 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
642 // * The StringDictionary is not local, so call string_dict_->makeLambdaStringToId()
643 // to make a lookup hash.
644 // * Store the old_id -> new_id translation into the id_map_.
648  using Lambda = std::function<int32_t(std::string const&)>;
650 
651  public:
653  : sdp_(sdp)
654  , id_map_(id_map)
655  , string_to_id_(sdp->string_dict_->makeLambdaStringToId()) {}
656  void operator()(std::string const& str, int32_t const old_id) override {
657  int32_t const new_id = string_to_id_(str);
658  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
660  : new_id;
661  }
662  void operator()(std::string_view const, int32_t const string_id) override {
663  UNREACHABLE() << "StringNetworkCallback requires a std::string.";
664  }
665 };
666 
667 // Union strings from both StringDictionaryProxies into *this as transients.
668 // Return id_map: sdp_rhs:string_id -> this:string_id for each string in sdp_rhs.
670  StringDictionaryProxy const& sdp_rhs) {
671  IdMap id_map = sdp_rhs.initIdMap();
672  // serial_callback cannot be parallelized due to calling getOrAddTransientUnlocked().
673  std::unique_ptr<StringDictionary::StringCallback> serial_callback;
674  if (string_dict_->isClient()) {
675  serial_callback = std::make_unique<StringNetworkCallback>(this, id_map);
676  } else {
677  serial_callback = std::make_unique<StringLocalCallback>(this, id_map);
678  }
679  // Import all non-duplicate strings (transient and non-transient) and add to id_map.
680  sdp_rhs.eachStringSerially(*serial_callback);
681  return id_map;
682 }
683 
684 void StringDictionaryProxy::updateGeneration(const int64_t generation) noexcept {
685  if (generation == -1) {
686  return;
687  }
688  if (generation_ != -1) {
689  CHECK_EQ(generation_, generation);
690  return;
691  }
692  generation_ = generation;
693 }
694 
696  const std::vector<std::string>& strings,
697  int32_t* string_ids,
698  const bool take_read_lock) const {
699  const size_t num_strings = strings.size();
700  if (num_strings == 0) {
701  return 0UL;
702  }
703  // StringDictionary::getBulk returns the number of strings not found
704  if (string_dict_->getBulk(strings, string_ids, generation_) == 0UL) {
705  return 0UL;
706  }
707 
708  // If here, dictionary could not find at least 1 target string,
709  // now look these up in the transient dictionary
710  // transientLookupBulk returns the number of strings not found
711  return transientLookupBulk(strings, string_ids, take_read_lock);
712 }
713 
714 template <typename String>
716  const std::vector<String>& lookup_strings,
717  int32_t* string_ids,
718  const bool take_read_lock) const {
719  const size_t num_strings = lookup_strings.size();
720  auto read_lock = take_read_lock ? std::shared_lock<std::shared_mutex>(rw_mutex_)
721  : std::shared_lock<std::shared_mutex>();
722 
723  if (num_strings == static_cast<size_t>(0) || transient_str_to_int_.empty()) {
724  return 0UL;
725  }
726  constexpr size_t tbb_parallel_threshold{20000};
727  if (num_strings < tbb_parallel_threshold) {
728  return transientLookupBulkUnlocked(lookup_strings, string_ids);
729  } else {
730  return transientLookupBulkParallelUnlocked(lookup_strings, string_ids);
731  }
732 }
733 
734 template <typename String>
736  const std::vector<String>& lookup_strings,
737  int32_t* string_ids) const {
738  const size_t num_strings = lookup_strings.size();
739  size_t num_strings_not_found = 0;
740  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
741  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
742  continue;
743  }
744  // If we're here it means we need to look up this string as we don't
745  // have a valid id for it
746  string_ids[string_idx] = lookupTransientStringUnlocked(lookup_strings[string_idx]);
747  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
748  num_strings_not_found++;
749  }
750  }
751  return num_strings_not_found;
752 }
753 
754 template <typename String>
756  const std::vector<String>& lookup_strings,
757  int32_t* string_ids) const {
758  const size_t num_lookup_strings = lookup_strings.size();
759  const size_t target_inputs_per_thread = 20000L;
760  ThreadInfo thread_info(
761  std::thread::hardware_concurrency(), num_lookup_strings, target_inputs_per_thread);
762  CHECK_GE(thread_info.num_threads, 1L);
763  CHECK_GE(thread_info.num_elems_per_thread, 1L);
764 
765  std::vector<size_t> num_strings_not_found_per_thread(thread_info.num_threads, 0UL);
766 
767  tbb::task_arena limited_arena(thread_info.num_threads);
768  limited_arena.execute([&] {
770  tbb::blocked_range<size_t>(
771  0, num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */),
772  [&](const tbb::blocked_range<size_t>& r) {
773  const size_t start_idx = r.begin();
774  const size_t end_idx = r.end();
775  size_t num_local_strings_not_found = 0;
776  for (size_t string_idx = start_idx; string_idx < end_idx; ++string_idx) {
777  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
778  continue;
779  }
780  string_ids[string_idx] =
781  lookupTransientStringUnlocked(lookup_strings[string_idx]);
782  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
783  num_local_strings_not_found++;
784  }
785  }
786  const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index();
787  num_strings_not_found_per_thread[tbb_thread_idx] = num_local_strings_not_found;
788  },
789  tbb::simple_partitioner());
790  });
791  size_t num_strings_not_found = 0;
792  for (int64_t thread_idx = 0; thread_idx < thread_info.num_threads; ++thread_idx) {
793  num_strings_not_found += num_strings_not_found_per_thread[thread_idx];
794  }
795  return num_strings_not_found;
796 }
797 
799  return string_dict_.get();
800 }
801 
802 int64_t StringDictionaryProxy::getGeneration() const noexcept {
803  return generation_;
804 }
805 
// NOTE(review): this Doxygen listing dropped the signature (original line 806) and the
// second operand of the `&&` (original line 808). The full equality criterion therefore
// cannot be read here. Visible part: two proxies compare equal only if their
// string_dict_id_ values match, plus whatever the missing line requires. Restore from the
// real source.
807  return string_dict_id_ == rhs.string_dict_id_ &&
809 }
810 
812  return !operator==(rhs);
813 }
void eachStringSerially(StringDictionary::StringCallback &) const
void setNumUntranslatedStrings(const size_t num_untranslated_strings)
#define CHECK_EQ(x, y)
Definition: Logger.h:230
std::pair< const char *, size_t > getStringBytes(int32_t string_id) const noexcept
std::vector< int32_t > getLike(const std::string &pattern, const bool icase, const bool is_simple, const char escape) const
size_t transientEntryCountUnlocked() const
StringLocalCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
int64_t num_elems_per_thread
Definition: ThreadInfo.h:23
StringDictionaryProxy::IdMap & id_map_
heavyai::shared_lock< heavyai::shared_mutex > read_lock
size_t entryCount() const
Returns the number of total string entries for this proxy, both stored in the underlying dictionary a...
int32_t getIdOfStringNoGeneration(const std::string &str) const
std::function< int32_t(std::string const &)> Lambda
std::string getStringUnlocked(const int32_t string_id) const
size_t storageEntryCount() const
Returns the number of string entries in the underlying string dictionary, at this proxy's generation_...
#define UNREACHABLE()
Definition: Logger.h:266
StringDictionary * getDictionary() const noexcept
#define CHECK_GE(x, y)
Definition: Logger.h:235
size_t transientLookupBulkUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
StringDictionaryProxy * sdp_
void operator()(std::string const &str, int32_t const string_id) override
size_t transientLookupBulk(const std::vector< String > &lookup_strings, int32_t *string_ids, const bool take_read_lock) const
std::string getString(int32_t string_id) const
Constants for Builtin SQL Types supported by HEAVY.AI.
heavyai::unique_lock< heavyai::shared_mutex > write_lock
IdMap buildIntersectionTranslationMapToOtherProxyUnlocked(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
size_t transientLookupBulkParallelUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
#define CHECK_GT(x, y)
Definition: Logger.h:234
int32_t getIdOfStringFromClient(String const &) const
std::vector< int32_t > getTransientBulk(const std::vector< std::string > &strings) const
Executes read-only lookup of a vector of strings and returns a vector of their integer ids...
std::string to_string(char const *&&v)
TranslationMap< Datum > buildNumericTranslationMap(const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
Builds a vectorized string_id translation map from this proxy to dest_proxy.
std::vector< int32_t > getCompare(const std::string &pattern, const std::string &comp_operator) const
#define DEVICE
bool is_regexp_like(const std::string &str, const std::string &pattern, const char escape)
StringNetworkCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
static constexpr int32_t INVALID_STR_ID
std::shared_ptr< StringDictionary > string_dict_
int64_t num_threads
Definition: ThreadInfo.h:22
IdMap transientUnion(StringDictionaryProxy const &)
std::vector< std::string const * > transient_string_vec_
void setRangeEnd(const int32_t range_end)
void order_translation_locks(const int32_t source_db_id, const int32_t source_dict_id, const int32_t dest_db_id, const int32_t dest_dict_id, std::shared_lock< std::shared_mutex > &source_read_lock, std::shared_lock< std::shared_mutex > &dest_read_lock)
RUNTIME_EXPORT DEVICE bool string_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:244
void operator()(std::string const &str, int32_t const old_id) override
int32_t lookupTransientStringUnlocked(const String &lookup_string) const
std::vector< std::string > getStrings(const std::vector< int32_t > &string_ids) const
size_t getTransientBulkImpl(const std::vector< std::string > &strings, int32_t *string_ids, const bool take_read_lock) const
RUNTIME_EXPORT DEVICE bool string_like_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len)
Definition: StringLike.cpp:41
bool is_like(const std::string &str, const std::string &pattern, const bool icase, const bool is_simple, const char escape)
void operator()(std::string_view const sv, int32_t const old_id) override
static int32_t transientIndexToId(unsigned const index)
void updateGeneration(const int64_t generation) noexcept
size_t transientEntryCount() const
Returns the number of transient string entries for this proxy,.
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:296
Functions to support the LIKE and ILIKE operator in SQL. Only single-byte character set is supported ...
IdMap buildUnionTranslationMapToOtherProxy(StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_types) const
StringDictionaryProxy(StringDictionaryProxy const &)=delete
void setRangeStart(const int32_t range_start)
#define RUNTIME_EXPORT
#define CHECK_LT(x, y)
Definition: Logger.h:232
void operator()(std::string_view const, int32_t const string_id) override
bool do_compare(const std::string &str, const std::string &pattern, const std::string &comp_operator)
#define CHECK_LE(x, y)
Definition: Logger.h:233
StringDictionaryProxy * sdp_
int32_t getOrAddTransientUnlocked(String const &)
bool operator!=(StringDictionaryProxy const &) const
std::vector< int32_t > getRegexpLike(const std::string &pattern, const char escape) const
int32_t getOrAdd(const std::string &str) noexcept
RUNTIME_EXPORT DEVICE bool string_ilike_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len)
Definition: StringLike.cpp:57
bool operator==(StringDictionaryProxy const &) const
std::vector< T > const & getVectorMap() const
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
int32_t getDictId() const noexcept
std::vector< int32_t > getOrAddTransientBulk(const std::vector< std::string > &strings)
IdMap buildIntersectionTranslationMapToOtherProxy(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
#define CHECK(condition)
Definition: Logger.h:222
DEVICE RUNTIME_EXPORT int32_t StringDictionaryProxy_getStringId(int8_t *proxy_ptr, char *c_str_ptr)
#define DEBUG_TIMER(name)
Definition: Logger.h:371
int32_t getOrAddTransient(const std::string &str)
DEVICE RUNTIME_EXPORT size_t StringDictionaryProxy_getStringLength(int8_t *proxy_ptr, int32_t string_id)
Definition: Datum.h:44
RUNTIME_EXPORT DEVICE bool regexp_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: Regexp.cpp:39
int32_t getIdOfString(const std::string &str) const
static unsigned transientIdToIndex(int32_t const id)
int64_t getGeneration() const noexcept
#define VLOG(n)
Definition: Logger.h:316
int32_t truncate_to_generation(const int32_t id, const size_t generation)
DEVICE RUNTIME_EXPORT const char * StringDictionaryProxy_getStringBytes(int8_t *proxy_ptr, int32_t string_id)
StringDictionaryProxy::IdMap & id_map_
RUNTIME_EXPORT DEVICE bool string_ilike(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:255