OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MLPredictCodegen.cpp File Reference
#include "CodeGenerator.h"
#include "QueryEngine/TableFunctions/SystemFunctions/os/ML/MLModel.h"
#include "TreeModelPredictionMgr.h"
#include <tbb/parallel_for.h>
#include <stack>
#include <vector>
+ Include dependency graph for MLPredictCodegen.cpp:

Go to the source code of this file.

Functions

std::vector< std::shared_ptr
< Analyzer::Expr > > 
generated_encoded_and_casted_features (const std::vector< std::shared_ptr< Analyzer::Expr >> &feature_exprs, const std::vector< std::vector< std::string >> &cat_feature_keys, const std::vector< int64_t > &feature_permutations, Executor *executor)
 

Function Documentation

std::vector<std::shared_ptr<Analyzer::Expr> > generated_encoded_and_casted_features ( const std::vector< std::shared_ptr< Analyzer::Expr >> &  feature_exprs,
const std::vector< std::vector< std::string >> &  cat_feature_keys,
const std::vector< int64_t > &  feature_permutations,
Executor *  executor 
)

Definition at line 30 of file MLPredictCodegen.cpp.

References CHECK, Datum::doubleval, Datum::intval, kBOOLEAN, kCAST, kDOUBLE, kEQ, kINT, kISNULL, and kONE.

Referenced by CodeGenerator::codegen(), CodeGenerator::codegenLinRegPredict(), and CodeGenerator::codegenTreeRegPredict().

34  {
35  std::vector<std::shared_ptr<Analyzer::Expr>> casted_feature_exprs;
36  const size_t num_feature_exprs = feature_exprs.size();
37  const size_t num_cat_features = cat_feature_keys.size();
38 
39  if (num_cat_features > num_feature_exprs) {
40  throw std::runtime_error("More categorical keys than features.");
41  }
42 
43  auto get_int_constant_expr = [](int32_t const_val) {
44  Datum d;
45  d.intval = const_val;
46  return makeExpr<Analyzer::Constant>(SQLTypeInfo(kINT, false), false, d);
47  };
48 
49  for (size_t original_feature_idx = 0; original_feature_idx < num_feature_exprs;
50  ++original_feature_idx) {
51  const auto feature_idx = feature_permutations.empty()
52  ? original_feature_idx
53  : feature_permutations[original_feature_idx];
54  auto& feature_expr = feature_exprs[feature_idx];
55  const auto& feature_ti = feature_expr->get_type_info();
56  if (feature_ti.is_number()) {
57  // Don't conditionally cast to double iff type is not double
58  // as this was causing issues for the random forest function with
59  // mixed types. Need to troubleshoot more but always casting to double
60  // regardless of the underlying type always seems to be safe
61  casted_feature_exprs.emplace_back(makeExpr<Analyzer::UOper>(
62  SQLTypeInfo(kDOUBLE, false), false, kCAST, feature_expr));
63  } else {
64  CHECK(feature_ti.is_string()) << "Expected text type";
65  if (!feature_ti.is_text_encoding_dict()) {
66  throw std::runtime_error("Expected dictionary-encoded text column.");
67  }
68  if (original_feature_idx >= num_cat_features) {
69  throw std::runtime_error("Model not trained on text type for column.");
70  }
71  const auto& str_dict_key = feature_ti.getStringDictKey();
72  const auto str_dict_proxy = executor->getStringDictionaryProxy(str_dict_key, true);
73  for (const auto& cat_feature_key : cat_feature_keys[original_feature_idx]) {
74  // For one-hot encoded columns, null values will translate as a 0.0 and not a null
75  // We are computing the following:
76  // CASE WHEN str_val is NULL then 0.0 ELSE
77  // CAST(str_id = one_hot_encoded_str_id AS DOUBLE) END
78 
79  // Check if the expression is null
80  auto is_null_expr = makeExpr<Analyzer::UOper>(
81  SQLTypeInfo(kBOOLEAN, false), false, kISNULL, feature_expr);
82  Datum zero_datum;
83  zero_datum.doubleval = 0.0;
84  // If null then emit a 0.0 double constant as the THEN expr
85  auto is_null_then_expr =
86  makeExpr<Analyzer::Constant>(SQLTypeInfo(kDOUBLE, false), false, zero_datum);
87  std::list<
88  std::pair<std::shared_ptr<Analyzer::Expr>, std::shared_ptr<Analyzer::Expr>>>
89  when_then_exprs;
90  when_then_exprs.emplace_back(std::make_pair(is_null_expr, is_null_then_expr));
91  // The rest of/core string test logic goes in the ELSE statement
92  // Get the string id of the one-hot feature
93  const auto str_id = str_dict_proxy->getIdOfString(cat_feature_key);
94  auto str_id_expr = get_int_constant_expr(str_id);
95  // Get integer id for this row's string
96  auto key_for_string_expr = makeExpr<Analyzer::KeyForStringExpr>(feature_expr);
97 
98  // Check if this row's string id is equal to the search one-hot encoded id
99  std::shared_ptr<Analyzer::Expr> str_equality_expr =
100  makeExpr<Analyzer::BinOper>(SQLTypeInfo(kBOOLEAN, false),
101  false,
102  kEQ,
103  kONE,
104  key_for_string_expr,
105  str_id_expr);
106  // Cast the above boolean results to a double, 0.0 or 1.0
107  auto cast_expr = makeExpr<Analyzer::UOper>(
108  SQLTypeInfo(kDOUBLE, false), false, kCAST, str_equality_expr);
109 
110  // Generate the full CASE statement and add to the casted feature exprssions
111  casted_feature_exprs.emplace_back(makeExpr<Analyzer::CaseExpr>(
112  SQLTypeInfo(kDOUBLE, false), false, when_then_exprs, cast_expr));
113  }
114  }
115  }
116  return casted_feature_exprs;
117 }
Definition: sqldefs.h:48
Definition: sqldefs.h:29
int32_t intval
Definition: Datum.h:73
Definition: sqldefs.h:71
#define CHECK(condition)
Definition: Logger.h:291
Definition: sqltypes.h:72
Definition: Datum.h:69
double doubleval
Definition: Datum.h:76

+ Here is the caller graph for this function: