OmniSciDB  a987f07e93
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StringDictionaryProxy.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2022 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include "Logger/Logger.h"
20 #include "Shared/ThreadInfo.h"
21 #include "Shared/misc.h"
22 #include "Shared/sqltypes.h"
23 #include "Shared/thread_count.h"
25 #include "StringOps/StringOps.h"
26 #include "Utils/Regexp.h"
27 #include "Utils/StringLike.h"
28 
29 #include <tbb/parallel_for.h>
30 #include <tbb/task_arena.h>
31 
32 #include <algorithm>
33 #include <iomanip>
34 #include <iostream>
35 #include <string>
36 #include <string_view>
37 #include <thread>
38 
39 StringDictionaryProxy::StringDictionaryProxy(std::shared_ptr<StringDictionary> sd,
40  const int32_t string_dict_id,
41  const int64_t generation)
42  : string_dict_(sd), string_dict_id_(string_dict_id), generation_(generation) {}
43 
44 int32_t truncate_to_generation(const int32_t id, const size_t generation) {
46  return id;
47  }
48  CHECK_GE(id, 0);
49  return static_cast<size_t>(id) >= generation ? StringDictionary::INVALID_STR_ID : id;
50 }
51 
// NOTE(review): the function header and one body line are missing from this extract;
// per the declarations listed with this file this is presumably
// std::vector<int32_t> StringDictionaryProxy::getTransientBulk(strings) const.
// Read-only bulk lookup: returns one id per input string, in input order. Ids are
// resolved against the underlying dictionary first and then against the transient
// strings (see getTransientBulkImpl). Unresolved strings are left as
// StringDictionary::INVALID_STR_ID. The `true` argument asks getTransientBulkImpl to
// take rw_mutex_ as a read lock for the transient part of the lookup.
53  const std::vector<std::string>& strings) const {
55  std::vector<int32_t> string_ids(strings.size());
56  getTransientBulkImpl(strings, string_ids.data(), true);
57  return string_ids;
58 }
59 
// NOTE(review): the function header and one line after it are missing from this
// extract; this is presumably StringDictionaryProxy::getOrAddTransientBulk(strings).
// Returns one id per input string, in input order. Every string that the underlying
// dictionary does not contain at generation_ is looked up among, or added to, this
// proxy's transient strings.
61  const std::vector<std::string>& strings) {
63  const size_t num_strings = strings.size();
64  std::vector<int32_t> string_ids(num_strings);
65  if (num_strings == 0) {
66  return string_ids;
67  }
68  // Since new strings added to a StringDictionaryProxy are not materialized in the
69  // proxy's underlying StringDictionary, we can use the fast parallel
70  // StringDictionary::getBulk method to fetch ids from the underlying dictionary (which
71  // will return StringDictionary::INVALID_STR_ID for strings that don't exist)
72 
73  // Don't need to be under lock here as the string ids for strings in the underlying
74  // materialized dictionary are immutable
75  const size_t num_strings_not_found =
76  string_dict_->getBulk(strings, string_ids.data(), generation_);
77  if (num_strings_not_found > 0) {
// The write lock is only taken when at least one string must be resolved
// against (or added to) the transient strings.
78  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
79  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
80  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
81  string_ids[string_idx] = getOrAddTransientUnlocked(strings[string_idx]);
82  }
83  }
84  }
85  return string_ids;
86 }
87 
// NOTE(review): the signature line is missing from this extract; from its call sites
// this is presumably StringDictionaryProxy::getOrAddTransientUnlocked(str).
// Returns the transient id of `str`, adding it if it is not yet a transient string.
// A new string receives the id of the next transient index,
// transientIndexToId(transient_str_to_int_.size()).
// Does not take rw_mutex_. Callers such as getOrAddTransientBulk and
// getOrAddTransientImpl acquire it as a write lock first.
// transient_string_vec_ stores a pointer to the key held inside transient_str_to_int_.
// That relies on the map never relocating its keys on insert, which holds for node-based
// std::map / std::unordered_map. NOTE(review): confirm the container type.
88 template <typename String>
90  unsigned const new_index = transient_str_to_int_.size();
91  auto transient_id = transientIndexToId(new_index);
92  auto const emplaced = transient_str_to_int_.emplace(str, transient_id);
93  if (emplaced.second) { // (str, transient_id) was added to transient_str_to_int_.
94  transient_string_vec_.push_back(&emplaced.first->first);
95  } else { // str already exists in transient_str_to_int_. Return existing transient_id.
96  transient_id = emplaced.first->second;
97  }
98  return transient_id;
99 }
100 
101 template <typename String>
103  auto const string_id = getIdOfStringFromClient(str);
104  if (string_id != StringDictionary::INVALID_STR_ID) {
105  return string_id;
106  }
107  std::lock_guard<std::shared_mutex> write_lock(rw_mutex_);
108  return getOrAddTransientUnlocked(str);
109 }
110 
111 int32_t StringDictionaryProxy::getOrAddTransient(std::string const& str) {
112  return getOrAddTransientImpl<std::string const&>(str);
113 }
114 
115 int32_t StringDictionaryProxy::getOrAddTransient(std::string_view const sv) {
116  return getOrAddTransientImpl<std::string_view const>(sv);
117 }
118 
119 int32_t StringDictionaryProxy::getIdOfString(const std::string& str) const {
120  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
121  auto const str_id = getIdOfStringFromClient(str);
122  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
123  return str_id;
124  }
125  auto it = transient_str_to_int_.find(str);
126  return it != transient_str_to_int_.end() ? it->second
128 }
129 
130 template <typename String>
131 int32_t StringDictionaryProxy::getIdOfStringFromClient(const String& str) const {
132  CHECK_GE(generation_, 0);
133  return truncate_to_generation(string_dict_->getIdOfString(str), generation_);
134 }
135 
136 int32_t StringDictionaryProxy::getIdOfStringNoGeneration(const std::string& str) const {
137  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
138  auto str_id = string_dict_->getIdOfString(str);
139  if (str_id != StringDictionary::INVALID_STR_ID || transient_str_to_int_.empty()) {
140  return str_id;
141  }
142  auto it = transient_str_to_int_.find(str);
143  return it != transient_str_to_int_.end() ? it->second
145 }
146 
// NOTE(review): the first line(s) of this signature are missing from this extract.
// This is presumably an extern "C" runtime entry point that returns the character data
// of the string for `string_id`; the sibling StringDictionaryProxy_getStringLength
// returns its length.
148  int8_t* proxy_ptr,
149  int32_t string_id) {
150  CHECK(proxy_ptr != nullptr);
// proxy_ptr is an opaque handle that must point to a StringDictionaryProxy.
151  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
152  auto [c_str, len] = proxy->getStringBytes(string_id);
153  return c_str;
154 }
155 
156 extern "C" DEVICE RUNTIME_EXPORT size_t
157 StringDictionaryProxy_getStringLength(int8_t* proxy_ptr, int32_t string_id) {
158  CHECK(proxy_ptr != nullptr);
159  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
160  auto [c_str, len] = proxy->getStringBytes(string_id);
161  return len;
162 }
163 
164 extern "C" DEVICE RUNTIME_EXPORT int32_t
165 StringDictionaryProxy_getStringId(int8_t* proxy_ptr, char* c_str_ptr) {
166  CHECK(proxy_ptr != nullptr);
167  auto proxy = reinterpret_cast<StringDictionaryProxy*>(proxy_ptr);
168  std::string str(c_str_ptr);
169  return proxy->getOrAddTransient(str);
170 }
171 
172 std::string StringDictionaryProxy::getString(int32_t string_id) const {
173  if (inline_int_null_value<int32_t>() == string_id) {
174  return "";
175  }
176  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
177  return getStringUnlocked(string_id);
178 }
179 
180 std::string StringDictionaryProxy::getStringUnlocked(const int32_t string_id) const {
181  if (string_id >= 0 && storageEntryCount() > 0) {
182  return string_dict_->getString(string_id);
183  }
184  unsigned const string_index = transientIdToIndex(string_id);
185  CHECK_LT(string_index, transient_string_vec_.size());
186  return *transient_string_vec_[string_index];
187 }
188 
189 std::vector<std::string> StringDictionaryProxy::getStrings(
190  const std::vector<int32_t>& string_ids) const {
191  std::vector<std::string> strings;
192  if (!string_ids.empty()) {
193  strings.reserve(string_ids.size());
194  for (const auto string_id : string_ids) {
195  if (string_id >= 0) {
196  strings.emplace_back(string_dict_->getString(string_id));
197  } else if (inline_int_null_value<int32_t>() == string_id) {
198  strings.emplace_back("");
199  } else {
200  unsigned const string_index = transientIdToIndex(string_id);
201  strings.emplace_back(*transient_string_vec_[string_index]);
202  }
203  }
204  }
205  return strings;
206 }
207 
208 template <typename String>
210  const String& lookup_string) const {
211  const auto it = transient_str_to_int_.find(lookup_string);
213  : it->second;
214 }
215 
// NOTE(review): the function header and the line declaring `translation_map` are
// missing from this extract. Per the declarations listed with this file this is
// presumably TranslationMap<Datum> buildNumericTranslationMap(string_op_infos) const.
// Builds a map from every string id of this proxy to the Datum produced by
// StringOps::numericEval over `string_op_infos`. Transient strings are evaluated here;
// persisted strings are delegated to the underlying dictionary.
218  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
219  auto timer = DEBUG_TIMER(__func__);
220  CHECK(string_op_infos.size());
222  if (translation_map.empty()) {
223  return translation_map;
224  }
225 
226  const StringOps_Namespace::StringOps string_ops(string_op_infos);
227 
228  const size_t num_transient_entries = translation_map.numTransients();
229  if (num_transient_entries) {
230  const int32_t map_domain_start = translation_map.domainStart();
// Transient ids occupy [map_domain_start, -1). They are evaluated in parallel (the
// parallel_for call line is missing from this extract) when there are more than
// 10000, serially otherwise.
// NOTE(review): getStringUnlocked is called without taking rw_mutex_; presumably
// callers guarantee no concurrent transient adds.
231  if (num_transient_entries > 10000UL) {
233  tbb::blocked_range<int32_t>(map_domain_start, -1),
234  [&](const tbb::blocked_range<int32_t>& r) {
235  const int32_t start_idx = r.begin();
236  const int32_t end_idx = r.end();
237  for (int32_t source_string_id = start_idx; source_string_id < end_idx;
238  ++source_string_id) {
239  const auto source_string = getStringUnlocked(source_string_id);
240  translation_map[source_string_id] = string_ops.numericEval(source_string);
241  }
242  });
243  } else {
244  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
245  ++source_string_id) {
246  const auto source_string = getStringUnlocked(source_string_id);
247  translation_map[source_string_id] = string_ops.numericEval(source_string);
248  }
249  }
250  }
251 
// Persisted (non-negative) ids are filled in by the underlying dictionary, up to this
// proxy's generation.
252  Datum* translation_map_stored_entries_ptr = translation_map.storageData();
253  if (generation_ > 0) {
254  string_dict_->buildDictionaryNumericTranslationMap(
255  translation_map_stored_entries_ptr, generation_, string_op_infos);
256  }
// Every entry receives a value, so nothing is reported as untranslated.
257  translation_map.setNumUntranslatedStrings(0UL);
258 
259  // Todo(todd): Set range start/end with scan
260 
261  return translation_map;
262 }
263 
// NOTE(review): the function header and a few lines (the start of both std::transform
// calls and the tail of the callback's comparison) are missing from this extract. Per
// the declarations listed with this file this is presumably
// IdMap buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos)
// const.
// Maps every string id of this proxy (transients first, then persisted strings) to the
// id of the same string, after applying `string_op_infos`, in dest_proxy. Strings that
// dest_proxy does not contain stay StringDictionary::INVALID_STR_ID, and their count is
// recorded via setNumUntranslatedStrings(). No locks are taken here; see the locking
// wrapper(s) that call this.
266  const StringDictionaryProxy* dest_proxy,
267  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
268  auto timer = DEBUG_TIMER(__func__);
269  IdMap id_map = initIdMap();
270 
271  if (id_map.empty()) {
272  return id_map;
273  }
274 
275  const StringOps_Namespace::StringOps string_ops(string_op_infos);
276 
277  // First map transient strings, store at front of vector map
278  const size_t num_transient_entries = id_map.numTransients();
279  size_t num_transient_strings_not_translated = 0UL;
280  if (num_transient_entries) {
281  std::vector<std::string> transient_lookup_strings(num_transient_entries);
// Copy the transients in reverse: transient_string_vec_[0] has the largest transient id
// (closest to -1), while the front of the id map holds the smallest (domainStart()).
282  if (string_ops.size()) {
284  transient_string_vec_.cend(),
285  transient_lookup_strings.rbegin(),
286  [&](std::string const* ptr) { return string_ops(*ptr); });
287  } else {
289  transient_string_vec_.cend(),
290  transient_lookup_strings.rbegin(),
291  [](std::string const* ptr) { return *ptr; });
292  }
293 
// NOTE(review): the comment below mentions a "dest_proxy_read_lock", but no lock is
// taken in this function; the locking wrappers acquire dest_proxy's rw_mutex_ before
// calling it.
294  // This lookup may have a different snapshot of
295  // dest_proxy transients and dictionary than what happens under
296  // the below dest_proxy_read_lock. We may need an unlocked version of
297  // getTransientBulk to ensure consistency (I don't believe
298  // current behavior would cause crashes/races, verify this though)
299 
300  // Todo(mattp): Consider variant of getTransientBulkImp that can take
301  // a vector of pointer-to-strings so we don't have to materialize
302  // transient_string_vec_ into transient_lookup_strings.
303 
304  num_transient_strings_not_translated =
305  dest_proxy->getTransientBulkImpl(transient_lookup_strings, id_map.data(), false);
306  }
307 
308  // Now map strings in dictionary
309  // We place non-transient strings after the transient strings
310  // if they exist, otherwise at index 0
311  int32_t* translation_map_stored_entries_ptr = id_map.storageData();
312 
// Called back by the underlying dictionary for persisted strings it could not
// translate. It tries dest_proxy's transients and returns whether the string is still
// untranslated (comparison tail missing from this extract).
313  auto dest_transient_lookup_callback = [dest_proxy, translation_map_stored_entries_ptr](
314  const std::string_view& source_string,
315  const int32_t source_string_id) {
316  translation_map_stored_entries_ptr[source_string_id] =
317  dest_proxy->lookupTransientStringUnlocked(source_string);
318  return translation_map_stored_entries_ptr[source_string_id] ==
320  };
321 
322  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
323  const size_t num_persisted_strings_not_translated =
324  generation_ > 0 ? string_dict_->buildDictionaryTranslationMap(
325  dest_proxy->string_dict_.get(),
326  translation_map_stored_entries_ptr,
327  generation_,
328  dest_proxy->generation_,
329  num_dest_transients > 0UL,
330  dest_transient_lookup_callback,
331  string_op_infos)
332  : 0UL;
333 
334  const size_t num_dest_entries = dest_proxy->entryCountUnlocked();
335  const size_t num_total_entries =
336  id_map.getVectorMap().size() - 1UL /* account for skipped entry -1 */;
337  CHECK_GT(num_total_entries, 0UL);
338  const size_t num_strings_not_translated =
339  num_transient_strings_not_translated + num_persisted_strings_not_translated;
340  CHECK_LE(num_strings_not_translated, num_total_entries);
341  id_map.setNumUntranslatedStrings(num_strings_not_translated);
342 
343  // Below is a conservative setting of range based on the size of the destination proxy,
344  // but probably not worth a scan over the data (or inline computation as we translate)
345  // to compute the actual ranges
346 
347  id_map.setRangeStart(
348  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
349  id_map.setRangeEnd(dest_proxy->storageEntryCount());
350 
351  const size_t num_entries_translated = num_total_entries - num_strings_not_translated;
352  const float match_pct =
353  100.0 * static_cast<float>(num_entries_translated) / num_total_entries;
354  VLOG(1) << std::fixed << std::setprecision(2) << match_pct << "% ("
355  << num_entries_translated << " entries) from dictionary ("
356  << string_dict_->getDbId() << ", " << string_dict_->getDictId() << ") with "
357  << num_total_entries << " total entries ( " << num_transient_entries
358  << " literals)"
359  << " translated to dictionary (" << dest_proxy->string_dict_->getDbId() << ", "
360  << dest_proxy->string_dict_->getDictId() << ") with " << num_dest_entries
361  << " total entries (" << dest_proxy->transientEntryCountUnlocked()
362  << " literals).";
363 
364  return id_map;
365 }
366 
/**
 * Acquires the (deferred) source read lock and destination write lock in a consistent
 * global order, ascending by dictionary id, so that two threads translating between
 * the same pair of proxies in opposite directions cannot deadlock.
 *
 * When both ids are equal the proxies are treated as the same object, so only the write
 * lock is taken. Also taking a shared lock on the same mutex from this thread would
 * self-deadlock.
 */
void order_translation_locks(const int32_t source_dict_id,
                             const int32_t dest_dict_id,
                             std::shared_lock<std::shared_mutex>& source_proxy_read_lock,
                             std::unique_lock<std::shared_mutex>& dest_proxy_write_lock) {
  if (source_dict_id == dest_dict_id) {
    dest_proxy_write_lock.lock();
    return;
  }
  const bool lock_source_first = source_dict_id < dest_dict_id;
  if (lock_source_first) {
    source_proxy_read_lock.lock();
    dest_proxy_write_lock.lock();
  } else {
    dest_proxy_write_lock.lock();
    source_proxy_read_lock.lock();
  }
}
382 
// NOTE(review): the function header and the `order_translation_locks(` call line are
// missing from this extract. This is presumably the locking wrapper
// buildIntersectionTranslationMapToOtherProxy(dest_proxy, string_op_infos) const.
// Acquires this proxy's read lock and dest_proxy's write lock in deadlock-safe order,
// then delegates to buildIntersectionTranslationMapToOtherProxyUnlocked.
// NOTE(review): a write (unique) lock is taken on dest_proxy even though dest_proxy is a
// const pointer here and the Unlocked builder only reads it. Presumably conservative;
// confirm.
385  const StringDictionaryProxy* dest_proxy,
386  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
387  const auto source_dict_id = getDictId();
388  const auto dest_dict_id = dest_proxy->getDictId();
389 
// Both locks are deferred so that order_translation_locks decides the acquisition order.
390  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
391  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
392  std::defer_lock);
394  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
395  return buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
396 }
397 
// NOTE(review): the function header and the `order_translation_locks(` call line are
// missing from this extract. This is presumably
// buildUnionTranslationMapToOtherProxy(dest_proxy, string_op_infos) const.
// Builds the intersection translation map to dest_proxy. Every source string that could
// not be translated is then added to dest_proxy as a transient string, so every entry of
// the returned map is a valid dest id. Holds this proxy's read lock and dest_proxy's write
// lock for the whole operation.
// @throws std::runtime_error if dest_proxy would exceed the transient-id capacity.
399  StringDictionaryProxy* dest_proxy,
400  const std::vector<StringOps_Namespace::StringOpInfo>& string_op_infos) const {
401  auto timer = DEBUG_TIMER(__func__);
402 
403  const auto source_dict_id = getDictId();
404  const auto dest_dict_id = dest_proxy->getDictId();
405  std::shared_lock<std::shared_mutex> source_proxy_read_lock(rw_mutex_, std::defer_lock);
406  std::unique_lock<std::shared_mutex> dest_proxy_write_lock(dest_proxy->rw_mutex_,
407  std::defer_lock);
409  source_dict_id, dest_dict_id, source_proxy_read_lock, dest_proxy_write_lock);
410 
411  auto id_map =
412  buildIntersectionTranslationMapToOtherProxyUnlocked(dest_proxy, string_op_infos);
413  if (id_map.empty()) {
414  return id_map;
415  }
// This early-out relies on numUntranslatedStrings() being an exact count. If it
// undercounts, untranslated entries are left as INVALID_STR_ID.
416  const auto num_untranslated_strings = id_map.numUntranslatedStrings();
417  if (num_untranslated_strings > 0) {
418  const size_t total_post_translation_dest_transients =
419  num_untranslated_strings + dest_proxy->transientEntryCountUnlocked();
420  constexpr size_t max_allowed_transients =
421  static_cast<size_t>(std::numeric_limits<int32_t>::max() -
422  2); /* -2 accounts for INVALID_STR_ID and NULL value */
// NOTE(review): the message reports getDictId() (the source dictionary) after the words
// "to dictionary" and lacks a space before the id. It presumably meant dest_dict_id.
423  if (total_post_translation_dest_transients > max_allowed_transients) {
424  throw std::runtime_error("Union translation to dictionary" +
425  std::to_string(getDictId()) + " would result in " +
426  std::to_string(total_post_translation_dest_transients) +
427  " transient entries, which is more than limit of " +
428  std::to_string(max_allowed_transients) + " transients.");
429  }
430  const int32_t map_domain_start = id_map.domainStart();
431  const int32_t map_domain_end = id_map.domainEnd();
432 
433  const StringOps_Namespace::StringOps string_ops(string_op_infos);
434  const bool has_string_ops = string_ops.size();
435 
436  // First iterate over transient strings and add to dest map
437  // Todo (todd): Add call to fetch string_views (local) or strings (distributed)
438  // for all non-translated ids to avoid string-by-string fetch
439 
440  for (int32_t source_string_id = map_domain_start; source_string_id < -1;
441  ++source_string_id) {
442  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
443  const auto source_string = getStringUnlocked(source_string_id);
444  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
445  has_string_ops ? string_ops(source_string) : source_string);
446  id_map[source_string_id] = dest_string_id;
447  }
448  }
449  // Now iterate over stored strings
450  for (int32_t source_string_id = 0; source_string_id < map_domain_end;
451  ++source_string_id) {
452  if (id_map[source_string_id] == StringDictionary::INVALID_STR_ID) {
453  const auto source_string = string_dict_->getString(source_string_id);
454  const auto dest_string_id = dest_proxy->getOrAddTransientUnlocked(
455  has_string_ops ? string_ops(source_string) : source_string);
456  id_map[source_string_id] = dest_string_id;
457  }
458  }
459  }
460  // We may have added transients to the destination proxy, use this to update
461  // our id map range (used downstream for ExpressionRange)
462 
463  const size_t num_dest_transients = dest_proxy->transientEntryCountUnlocked();
464  id_map.setRangeStart(
465  num_dest_transients > 0 ? -1 - static_cast<int32_t>(num_dest_transients) : 0);
466  return id_map;
467 }
468 
469 namespace {
470 
471 bool is_like(const std::string& str,
472  const std::string& pattern,
473  const bool icase,
474  const bool is_simple,
475  const char escape) {
476  return icase
477  ? (is_simple ? string_ilike_simple(
478  str.c_str(), str.size(), pattern.c_str(), pattern.size())
479  : string_ilike(str.c_str(),
480  str.size(),
481  pattern.c_str(),
482  pattern.size(),
483  escape))
484  : (is_simple ? string_like_simple(
485  str.c_str(), str.size(), pattern.c_str(), pattern.size())
486  : string_like(str.c_str(),
487  str.size(),
488  pattern.c_str(),
489  pattern.size(),
490  escape));
491 }
492 
493 } // namespace
494 
495 std::vector<int32_t> StringDictionaryProxy::getLike(const std::string& pattern,
496  const bool icase,
497  const bool is_simple,
498  const char escape) const {
499  CHECK_GE(generation_, 0);
500  auto result = string_dict_->getLike(pattern, icase, is_simple, escape, generation_);
501  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
502  if (is_like(*transient_string_vec_[index], pattern, icase, is_simple, escape)) {
503  result.push_back(transientIndexToId(index));
504  }
505  }
506  return result;
507 }
508 
namespace {

// Evaluates `str <comp_operator> pattern` using std::string::compare ordering.
// Supported operators: "<", "<=", "=", ">", ">=", "<>".
// @throws std::runtime_error for any other operator.
bool do_compare(const std::string& str,
                const std::string& pattern,
                const std::string& comp_operator) {
  const int order = str.compare(pattern);
  if (comp_operator == "=") {
    return order == 0;
  }
  if (comp_operator == "<>") {
    return order != 0;
  }
  if (comp_operator == "<") {
    return order < 0;
  }
  if (comp_operator == "<=") {
    return order <= 0;
  }
  if (comp_operator == ">") {
    return order > 0;
  }
  if (comp_operator == ">=") {
    return order >= 0;
  }
  throw std::runtime_error("unsupported string compare operator");
}

}  // namespace
532 
534  const std::string& pattern,
535  const std::string& comp_operator) const {
536  CHECK_GE(generation_, 0);
537  auto result = string_dict_->getCompare(pattern, comp_operator, generation_);
538  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
539  if (do_compare(*transient_string_vec_[index], pattern, comp_operator)) {
540  result.push_back(transientIndexToId(index));
541  }
542  }
543  return result;
544 }
545 
546 namespace {
547 
548 bool is_regexp_like(const std::string& str,
549  const std::string& pattern,
550  const char escape) {
551  return regexp_like(str.c_str(), str.size(), pattern.c_str(), pattern.size(), escape);
552 }
553 
554 } // namespace
555 
556 std::vector<int32_t> StringDictionaryProxy::getRegexpLike(const std::string& pattern,
557  const char escape) const {
558  CHECK_GE(generation_, 0);
559  auto result = string_dict_->getRegexpLike(pattern, escape, generation_);
560  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
561  if (is_regexp_like(*transient_string_vec_[index], pattern, escape)) {
562  result.push_back(transientIndexToId(index));
563  }
564  }
565  return result;
566 }
567 
568 int32_t StringDictionaryProxy::getOrAdd(const std::string& str) noexcept {
569  return string_dict_->getOrAdd(str);
570 }
571 
572 std::pair<const char*, size_t> StringDictionaryProxy::getStringBytes(
573  int32_t string_id) const noexcept {
574  if (string_id >= 0) {
575  return string_dict_.get()->getStringBytes(string_id);
576  }
577  unsigned const string_index = transientIdToIndex(string_id);
578  CHECK_LT(string_index, transient_string_vec_.size());
579  std::string const* const str_ptr = transient_string_vec_[string_index];
580  return {str_ptr->c_str(), str_ptr->size()};
581 }
582 
584  const size_t num_storage_entries{generation_ == -1 ? string_dict_->storageEntryCount()
585  : generation_};
586  CHECK_LE(num_storage_entries, static_cast<size_t>(std::numeric_limits<int32_t>::max()));
587  return num_storage_entries;
588 }
589 
591  // CHECK_LE(num_storage_entries,
592  // static_cast<size_t>(std::numeric_limits<int32_t>::max()));
593  const size_t num_transient_entries{transient_str_to_int_.size()};
594  CHECK_LE(num_transient_entries,
595  static_cast<size_t>(std::numeric_limits<int32_t>::max()) - 1);
596  return num_transient_entries;
597 }
598 
// NOTE(review): the header and return statement of this function are missing from this
// extract. It presumably is transientEntryCount(), returning
// transientEntryCountUnlocked() while holding rw_mutex_ as a read lock. Confirm.
600  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
602 }
603 
// NOTE(review): this is the closing brace of a function whose header and body are
// missing from this extract, presumably entryCountUnlocked(), which entryCount()
// calls. Confirm.
606 }
607 
609  std::shared_lock<std::shared_mutex> read_lock(rw_mutex_);
610  return entryCountUnlocked();
611 }
612 
613 // Iterate over transient strings, then non-transients.
615  StringDictionary::StringCallback& serial_callback) const {
616  constexpr int32_t max_transient_id = -2;
617  // Iterate over transient strings.
618  for (unsigned index = 0; index < transient_string_vec_.size(); ++index) {
619  std::string const& str = *transient_string_vec_[index];
620  int32_t const string_id = max_transient_id - index;
621  serial_callback(str, string_id);
622  }
623  // Iterate over non-transient strings.
624  string_dict_->eachStringSerially(generation_, serial_callback);
625 }
626 
627 // For each (string/_view,old_id) pair passed in:
628 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
629 // * The StringDictionary is local, so call the faster getUnlocked() method.
630 // * Store the old_id -> new_id translation into the id_map_.
// NOTE(review): the class head (presumably
// `class StringLocalCallback : public StringDictionary::StringCallback {` with its sdp_
// and id_map_ members), the constructor signature, and one line of the conditional
// below are missing from this extract.
634 
635  public:
637  : sdp_(sdp), id_map_(id_map) {}
// std::string overload: forwards to the string_view overload.
638  void operator()(std::string const& str, int32_t const string_id) override {
639  operator()(std::string_view(str), string_id);
640  }
641  void operator()(std::string_view const sv, int32_t const old_id) override {
642  int32_t const new_id = sdp_->string_dict_->getUnlocked(sv);
// When sdp_'s dictionary lacks the string, the missing branch presumably adds it to
// sdp_ as a transient (per the comment above the class).
643  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
645  : new_id;
646  }
647 };
648 
649 // For each (string,old_id) pair passed in:
650 // * Get the new_id based on sdp_'s dictionary, or add it as a transient.
651 // * The StringDictionary is not local, so call string_dict_->makeLambdaStringToId()
652 // to make a lookup hash.
653 // * Store the old_id -> new_id translation into the id_map_.
// NOTE(review): the class head (presumably
// `class StringNetworkCallback : public StringDictionary::StringCallback {` with its
// sdp_, id_map_ and string_to_id_ members), the constructor signature, and one line of
// the conditional below are missing from this extract.
657  using Lambda = std::function<int32_t(std::string const&)>;
659 
660  public:
// The string -> id lookup is built once per callback object, at construction time.
662  : sdp_(sdp)
663  , id_map_(id_map)
664  , string_to_id_(sdp->string_dict_->makeLambdaStringToId()) {}
665  void operator()(std::string const& str, int32_t const old_id) override {
666  int32_t const new_id = string_to_id_(str);
// When the lookup fails, the missing branch presumably adds the string to sdp_ as a
// transient (per the comment above the class).
667  id_map_[old_id] = new_id == StringDictionary::INVALID_STR_ID
669  : new_id;
670  }
// Only std::string inputs are supported by this callback.
671  void operator()(std::string_view const, int32_t const string_id) override {
672  UNREACHABLE() << "StringNetworkCallback requires a std::string.";
673  }
674 };
675 
676 // Union strings from both StringDictionaryProxies into *this as transients.
677 // Return id_map: sdp_rhs:string_id -> this:string_id for each string in sdp_rhs.
679  StringDictionaryProxy const& sdp_rhs) {
680  IdMap id_map = sdp_rhs.initIdMap();
681  // serial_callback cannot be parallelized due to calling getOrAddTransientUnlocked().
682  std::unique_ptr<StringDictionary::StringCallback> serial_callback;
683  if (string_dict_->isClient()) {
684  serial_callback = std::make_unique<StringNetworkCallback>(this, id_map);
685  } else {
686  serial_callback = std::make_unique<StringLocalCallback>(this, id_map);
687  }
688  // Import all non-duplicate strings (transient and non-transient) and add to id_map.
689  sdp_rhs.eachStringSerially(*serial_callback);
690  return id_map;
691 }
692 
693 void StringDictionaryProxy::updateGeneration(const int64_t generation) noexcept {
694  if (generation == -1) {
695  return;
696  }
697  if (generation_ != -1) {
698  CHECK_EQ(generation_, generation);
699  return;
700  }
701  generation_ = generation;
702 }
703 
705  const std::vector<std::string>& strings,
706  int32_t* string_ids,
707  const bool take_read_lock) const {
708  const size_t num_strings = strings.size();
709  if (num_strings == 0) {
710  return 0UL;
711  }
712  // StringDictionary::getBulk returns the number of strings not found
713  if (string_dict_->getBulk(strings, string_ids, generation_) == 0UL) {
714  return 0UL;
715  }
716 
717  // If here, dictionary could not find at least 1 target string,
718  // now look these up in the transient dictionary
719  // transientLookupBulk returns the number of strings not found
720  return transientLookupBulk(strings, string_ids, take_read_lock);
721 }
722 
723 template <typename String>
725  const std::vector<String>& lookup_strings,
726  int32_t* string_ids,
727  const bool take_read_lock) const {
728  const size_t num_strings = lookup_strings.size();
729  auto read_lock = take_read_lock ? std::shared_lock<std::shared_mutex>(rw_mutex_)
730  : std::shared_lock<std::shared_mutex>();
731 
732  if (num_strings == static_cast<size_t>(0) || transient_str_to_int_.empty()) {
733  return 0UL;
734  }
735  constexpr size_t tbb_parallel_threshold{20000};
736  if (num_strings < tbb_parallel_threshold) {
737  return transientLookupBulkUnlocked(lookup_strings, string_ids);
738  } else {
739  return transientLookupBulkParallelUnlocked(lookup_strings, string_ids);
740  }
741 }
742 
743 template <typename String>
745  const std::vector<String>& lookup_strings,
746  int32_t* string_ids) const {
747  const size_t num_strings = lookup_strings.size();
748  size_t num_strings_not_found = 0;
749  for (size_t string_idx = 0; string_idx < num_strings; ++string_idx) {
750  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
751  continue;
752  }
753  // If we're here it means we need to look up this string as we don't
754  // have a valid id for it
755  string_ids[string_idx] = lookupTransientStringUnlocked(lookup_strings[string_idx]);
756  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
757  num_strings_not_found++;
758  }
759  }
760  return num_strings_not_found;
761 }
762 
763 template <typename String>
765  const std::vector<String>& lookup_strings,
766  int32_t* string_ids) const {
767  const size_t num_lookup_strings = lookup_strings.size();
768  const size_t target_inputs_per_thread = 20000L;
769  ThreadInfo thread_info(
770  std::thread::hardware_concurrency(), num_lookup_strings, target_inputs_per_thread);
771  CHECK_GE(thread_info.num_threads, 1L);
772  CHECK_GE(thread_info.num_elems_per_thread, 1L);
773 
774  std::vector<size_t> num_strings_not_found_per_thread(thread_info.num_threads, 0UL);
775 
776  tbb::task_arena limited_arena(thread_info.num_threads);
777  limited_arena.execute([&] {
779  tbb::blocked_range<size_t>(
780  0, num_lookup_strings, thread_info.num_elems_per_thread /* tbb grain_size */),
781  [&](const tbb::blocked_range<size_t>& r) {
782  const size_t start_idx = r.begin();
783  const size_t end_idx = r.end();
784  size_t num_local_strings_not_found = 0;
785  for (size_t string_idx = start_idx; string_idx < end_idx; ++string_idx) {
786  if (string_ids[string_idx] != StringDictionary::INVALID_STR_ID) {
787  continue;
788  }
789  string_ids[string_idx] =
790  lookupTransientStringUnlocked(lookup_strings[string_idx]);
791  if (string_ids[string_idx] == StringDictionary::INVALID_STR_ID) {
792  num_local_strings_not_found++;
793  }
794  }
795  const size_t tbb_thread_idx = tbb::this_task_arena::current_thread_index();
796  num_strings_not_found_per_thread[tbb_thread_idx] = num_local_strings_not_found;
797  },
798  tbb::simple_partitioner());
799  });
800  size_t num_strings_not_found = 0;
801  for (int64_t thread_idx = 0; thread_idx < thread_info.num_threads; ++thread_idx) {
802  num_strings_not_found += num_strings_not_found_per_thread[thread_idx];
803  }
804  return num_strings_not_found;
805 }
806 
808  return string_dict_.get();
809 }
810 
811 int64_t StringDictionaryProxy::getGeneration() const noexcept {
812  return generation_;
813 }
814 
// NOTE(review): the signature and the second operand of this comparison are missing
// from this extract. This is presumably operator==(StringDictionaryProxy const& rhs)
// const, comparing string_dict_id_ and at least one further member with rhs.
816  return string_dict_id_ == rhs.string_dict_id_ &&
818 }
819 
// NOTE(review): the signature is missing from this extract. This is presumably
// operator!=(StringDictionaryProxy const& rhs) const, defined as the negation of
// operator==.
821  return !operator==(rhs);
822 }
void eachStringSerially(StringDictionary::StringCallback &) const
int32_t getOrAddTransientImpl(String)
void setNumUntranslatedStrings(const size_t num_untranslated_strings)
#define CHECK_EQ(x, y)
Definition: Logger.h:297
std::pair< const char *, size_t > getStringBytes(int32_t string_id) const noexcept
std::vector< int32_t > getLike(const std::string &pattern, const bool icase, const bool is_simple, const char escape) const
size_t transientEntryCountUnlocked() const
StringLocalCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
int64_t num_elems_per_thread
Definition: ThreadInfo.h:23
StringDictionaryProxy::IdMap & id_map_
heavyai::shared_lock< heavyai::shared_mutex > read_lock
size_t entryCount() const
Returns the number of total string entries for this proxy, both stored in the underlying dictionary a...
int32_t getIdOfStringNoGeneration(const std::string &str) const
std::function< int32_t(std::string const &)> Lambda
std::string getStringUnlocked(const int32_t string_id) const
size_t storageEntryCount() const
Returns the number of string entries in the underlying string dictionary, at this proxy's generation_...
#define UNREACHABLE()
Definition: Logger.h:333
StringDictionary * getDictionary() const noexcept
#define CHECK_GE(x, y)
Definition: Logger.h:302
size_t transientLookupBulkUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
StringDictionaryProxy * sdp_
void operator()(std::string const &str, int32_t const string_id) override
size_t transientLookupBulk(const std::vector< String > &lookup_strings, int32_t *string_ids, const bool take_read_lock) const
std::string getString(int32_t string_id) const
Constants for Builtin SQL Types supported by HEAVY.AI.
heavyai::unique_lock< heavyai::shared_mutex > write_lock
IdMap buildIntersectionTranslationMapToOtherProxyUnlocked(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
size_t transientLookupBulkParallelUnlocked(const std::vector< String > &lookup_strings, int32_t *string_ids) const
#define CHECK_GT(x, y)
Definition: Logger.h:301
int32_t getIdOfStringFromClient(String const &) const
std::vector< int32_t > getTransientBulk(const std::vector< std::string > &strings) const
Executes read-only lookup of a vector of strings and returns a vector of their integer ids...
std::string to_string(char const *&&v)
TranslationMap< Datum > buildNumericTranslationMap(const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
Builds a vectorized string_id translation map from this proxy to dest_proxy.
std::vector< int32_t > getCompare(const std::string &pattern, const std::string &comp_operator) const
#define DEVICE
bool is_regexp_like(const std::string &str, const std::string &pattern, const char escape)
StringNetworkCallback(StringDictionaryProxy *sdp, StringDictionaryProxy::IdMap &id_map)
static constexpr int32_t INVALID_STR_ID
std::shared_ptr< StringDictionary > string_dict_
int64_t num_threads
Definition: ThreadInfo.h:22
IdMap transientUnion(StringDictionaryProxy const &)
std::vector< std::string const * > transient_string_vec_
void setRangeEnd(const int32_t range_end)
void order_translation_locks(const int32_t source_db_id, const int32_t source_dict_id, const int32_t dest_db_id, const int32_t dest_dict_id, std::shared_lock< std::shared_mutex > &source_read_lock, std::shared_lock< std::shared_mutex > &dest_read_lock)
RUNTIME_EXPORT DEVICE bool string_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:244
void operator()(std::string const &str, int32_t const old_id) override
int32_t lookupTransientStringUnlocked(const String &lookup_string) const
std::vector< std::string > getStrings(const std::vector< int32_t > &string_ids) const
size_t getTransientBulkImpl(const std::vector< std::string > &strings, int32_t *string_ids, const bool take_read_lock) const
RUNTIME_EXPORT DEVICE bool string_like_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len)
Definition: StringLike.cpp:41
bool is_like(const std::string &str, const std::string &pattern, const bool icase, const bool is_simple, const char escape)
void operator()(std::string_view const sv, int32_t const old_id) override
static int32_t transientIndexToId(unsigned const index)
void updateGeneration(const int64_t generation) noexcept
size_t transientEntryCount() const
Returns the number of transient string entries for this proxy,.
OUTPUT transform(INPUT const &input, FUNC const &func)
Definition: misc.h:320
Functions to support the LIKE and ILIKE operator in SQL. Only single-byte character set is supported ...
IdMap buildUnionTranslationMapToOtherProxy(StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_types) const
StringDictionaryProxy(StringDictionaryProxy const &)=delete
void setRangeStart(const int32_t range_start)
int32_t getOrAddTransient(const std::string &)
#define RUNTIME_EXPORT
#define CHECK_LT(x, y)
Definition: Logger.h:299
void operator()(std::string_view const, int32_t const string_id) override
bool do_compare(const std::string &str, const std::string &pattern, const std::string &comp_operator)
#define CHECK_LE(x, y)
Definition: Logger.h:300
StringDictionaryProxy * sdp_
int32_t getOrAddTransientUnlocked(String const &)
bool operator!=(StringDictionaryProxy const &) const
std::vector< int32_t > getRegexpLike(const std::string &pattern, const char escape) const
int32_t getOrAdd(const std::string &str) noexcept
RUNTIME_EXPORT DEVICE bool string_ilike_simple(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len)
Definition: StringLike.cpp:57
bool operator==(StringDictionaryProxy const &) const
std::vector< T > const & getVectorMap() const
void parallel_for(const blocked_range< Int > &range, const Body &body, const Partitioner &p=Partitioner())
int32_t getDictId() const noexcept
std::vector< int32_t > getOrAddTransientBulk(const std::vector< std::string > &strings)
IdMap buildIntersectionTranslationMapToOtherProxy(const StringDictionaryProxy *dest_proxy, const std::vector< StringOps_Namespace::StringOpInfo > &string_op_infos) const
#define CHECK(condition)
Definition: Logger.h:289
DEVICE RUNTIME_EXPORT int32_t StringDictionaryProxy_getStringId(int8_t *proxy_ptr, char *c_str_ptr)
#define DEBUG_TIMER(name)
Definition: Logger.h:407
DEVICE RUNTIME_EXPORT size_t StringDictionaryProxy_getStringLength(int8_t *proxy_ptr, int32_t string_id)
Definition: Datum.h:67
RUNTIME_EXPORT DEVICE bool regexp_like(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: Regexp.cpp:39
int32_t getIdOfString(const std::string &str) const
static unsigned transientIdToIndex(int32_t const id)
int64_t getGeneration() const noexcept
#define VLOG(n)
Definition: Logger.h:383
int32_t truncate_to_generation(const int32_t id, const size_t generation)
DEVICE RUNTIME_EXPORT const char * StringDictionaryProxy_getStringBytes(int8_t *proxy_ptr, int32_t string_id)
StringDictionaryProxy::IdMap & id_map_
RUNTIME_EXPORT DEVICE bool string_ilike(const char *str, const int32_t str_len, const char *pattern, const int32_t pat_len, const char escape_char)
Definition: StringLike.cpp:255