OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
Parser::CreateModelStmt Class Reference

#include <ParserNode.h>

+ Inheritance diagram for Parser::CreateModelStmt:
+ Collaboration diagram for Parser::CreateModelStmt:

Public Member Functions

 CreateModelStmt (const rapidjson::Value &payload)
 
const std::string & get_model_name () const
 
const std::string & get_select_query () const
 
void execute (const Catalog_Namespace::SessionInfo &session, bool read_only_mode) override
 
void train_model (const Catalog_Namespace::SessionInfo &session)
 
- Public Member Functions inherited from Parser::DDLStmt
void setColumnDescriptor (ColumnDescriptor &cd, const ColumnDef *coldef)
 
- Public Member Functions inherited from Parser::Node
virtual ~Node ()
 

Private Member Functions

bool check_model_exists ()
 
void parse_model_options ()
 
std::string build_model_query (const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
 

Private Attributes

MLModelType model_type_
 
std::string model_name_
 
std::string select_query_
 
bool replace_
 
bool if_not_exists_
 
std::list< std::unique_ptr
< NameValueAssign > > 
model_options_
 
std::ostringstream options_oss_
 
size_t num_options_ {0}
 
double data_split_train_fraction_ {1.0}
 
double data_split_eval_fraction_ {0.0}
 
std::string model_predicted_var_
 
std::vector< std::string > model_feature_vars_
 
std::vector< int64_t > feature_permutations_
 

Detailed Description

Definition at line 1959 of file ParserNode.h.

Constructor & Destructor Documentation

Parser::CreateModelStmt::CreateModelStmt ( const rapidjson::Value &  payload)

Definition at line 3439 of file ParserNode.cpp.

References CHECK, g_enable_ml_functions, get_ml_model_type_from_str(), if_not_exists_, json_bool(), json_str(), model_name_, model_options_, model_type_, Parser::anonymous_namespace{ParserNode.cpp}::parse_options(), replace_, and select_query_.

3439  {
3440  if (!g_enable_ml_functions) {
3441  throw std::runtime_error("Cannot create model. ML functions are disabled.");
3442  }
3443  CHECK(payload.HasMember("name"));
3444  const std::string model_type_str = json_str(payload["type"]);
3445  model_type_ = get_ml_model_type_from_str(model_type_str);
3446  model_name_ = json_str(payload["name"]);
3447  replace_ = false;
3448  if (payload.HasMember("replace")) {
3449  replace_ = json_bool(payload["replace"]);
3450  }
3451 
3452  if_not_exists_ = false;
3453  if (payload.HasMember("ifNotExists")) {
3454  if_not_exists_ = json_bool(payload["ifNotExists"]);
3455  }
3456 
3457  CHECK(payload.HasMember("query"));
3458  select_query_ = json_str(payload["query"]);
3459  std::regex newline_re("\\n");
3460  std::regex backtick_re("`");
3461  select_query_ = std::regex_replace(select_query_, newline_re, " ");
3462  select_query_ = std::regex_replace(select_query_, backtick_re, "");
3463 
3464  // No need to ensure trailing semicolon as we will wrap this select statement
3465  // in a CURSOR as input to the train model table function
3466  parse_options(payload, model_options_);
3467 }
std::list< std::unique_ptr< NameValueAssign > > model_options_
Definition: ParserNode.h:1975
const bool json_bool(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:51
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46
std::string select_query_
Definition: ParserNode.h:1972
void parse_options(const rapidjson::Value &payload, std::list< std::unique_ptr< NameValueAssign >> &nameValueList, bool stringToNull=false, bool stringToInteger=false)
bool g_enable_ml_functions
Definition: Execute.cpp:118
#define CHECK(condition)
Definition: Logger.h:291
MLModelType get_ml_model_type_from_str(const std::string &model_type_str)
Definition: MLModelType.h:52

+ Here is the call graph for this function:

Member Function Documentation

std::string Parser::CreateModelStmt::build_model_query ( const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
private

Definition at line 3643 of file ParserNode.cpp.

References query_state::QueryState::create(), data_split_train_fraction_, feature_permutations_, Parser::LocalQueryConnector::getColumnDescriptors(), is_regression_model(), model_feature_vars_, model_predicted_var_, model_type_, Parser::LocalQueryConnector::query(), and select_query_.

Referenced by train_model().

3644  {
3645  auto validate_query_state = query_state::QueryState::create(session_ptr, select_query_);
3646 
3647  LocalQueryConnector local_connector;
3648 
3649  auto validate_result = local_connector.query(
3650  validate_query_state->createQueryStateProxy(), select_query_, {}, true, false);
3651 
3652  auto column_descriptors_for_model_create =
3653  local_connector.getColumnDescriptors(validate_result, true);
3654 
3655  std::vector<size_t> categorical_feature_idxs;
3656  std::vector<size_t> numeric_feature_idxs;
3657  bool numeric_feature_seen = false;
3658  bool all_categorical_features_placed_first = true;
3659  bool model_has_predicted_var = is_regression_model(model_type_);
3660  model_feature_vars_.reserve(column_descriptors_for_model_create.size() -
3661  (model_has_predicted_var ? 1 : 0));
3662  bool is_predicted = model_has_predicted_var ? true : false;
3663  size_t feature_idx = 0;
3664  for (auto& cd : column_descriptors_for_model_create) {
3665  // Check to see if the projected column is an expression without a user-provided
3666  // alias, as we don't allow this.
3667  if (cd.columnName.rfind("EXPR$", 0) == 0) {
3668  throw std::runtime_error(
3669  "All projected expressions (i.e. col * 2) that are not column references (i.e. "
3670  "col) must be aliased.");
3671  }
3672  if (is_predicted) {
3673  model_predicted_var_ = cd.columnName;
3674  if (!cd.columnType.is_number()) {
3675  throw std::runtime_error(
3676  "Numeric predicted column expression should be first argument to CREATE "
3677  "MODEL.");
3678  }
3679  is_predicted = false;
3680  } else {
3681  if (cd.columnType.is_number()) {
3682  numeric_feature_idxs.emplace_back(feature_idx);
3683  numeric_feature_seen = true;
3684  } else if (cd.columnType.is_string()) {
3685  categorical_feature_idxs.emplace_back(feature_idx);
3686  if (numeric_feature_seen) {
3687  all_categorical_features_placed_first = false;
3688  }
3689  } else {
3690  throw std::runtime_error("Feature column expression should be numeric or TEXT.");
3691  }
3692  model_feature_vars_.emplace_back(cd.columnName);
3693  feature_idx++;
3694  }
3695  }
3696  auto modified_select_query = select_query_;
3697  if (!all_categorical_features_placed_first) {
3698  std::ostringstream modified_query_oss;
3699  modified_query_oss << "SELECT ";
3700  if (model_has_predicted_var) {
3701  modified_query_oss << model_predicted_var_ << ", ";
3702  }
3703  for (auto categorical_feature_idx : categorical_feature_idxs) {
3704  modified_query_oss << model_feature_vars_[categorical_feature_idx] << ", ";
3705  feature_permutations_.emplace_back(static_cast<int64_t>(categorical_feature_idx));
3706  }
3707  for (auto numeric_feature_idx : numeric_feature_idxs) {
3708  modified_query_oss << model_feature_vars_[numeric_feature_idx];
3709  feature_permutations_.emplace_back(static_cast<int64_t>(numeric_feature_idx));
3710  if (numeric_feature_idx != numeric_feature_idxs.back()) {
3711  modified_query_oss << ", ";
3712  }
3713  }
3714  modified_query_oss << " FROM (" << modified_select_query << ")";
3715  modified_select_query = modified_query_oss.str();
3716  }
3717 
3718  if (data_split_train_fraction_ < 1.0) {
3719  std::ostringstream modified_query_oss;
3720  if (all_categorical_features_placed_first) {
3721  modified_query_oss << "SELECT * FROM (" << modified_select_query << ")";
3722  } else {
3723  modified_query_oss << modified_select_query;
3724  }
3725  modified_query_oss << " WHERE SAMPLE_RATIO(" << data_split_train_fraction_ << ")";
3726  modified_select_query = modified_query_oss.str();
3727  }
3728  return modified_select_query;
3729 }
static std::shared_ptr< QueryState > create(ARGS &&...args)
Definition: QueryState.h:148
std::string model_predicted_var_
Definition: ParserNode.h:1980
std::vector< std::string > model_feature_vars_
Definition: ParserNode.h:1981
std::vector< int64_t > feature_permutations_
Definition: ParserNode.h:1982
std::string select_query_
Definition: ParserNode.h:1972
bool is_regression_model(const MLModelType model_type)
Definition: MLModelType.h:69

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

bool Parser::CreateModelStmt::check_model_exists ( )
private

Definition at line 3529 of file ParserNode.cpp.

References g_ml_models, get_model_name(), if_not_exists_, MLModelMap::modelExists(), and replace_.

Referenced by train_model().

3529  {
3530  if (g_ml_models.modelExists(get_model_name())) {
3531  if (if_not_exists_) {
3532  // Returning true tells the caller we should just return early and silently (without
3533  // error)
3534  return true;
3535  }
3536  if (!replace_) {
3537  std::ostringstream error_oss;
3538  error_oss << "Model " << get_model_name() << " already exists.";
3539  throw std::runtime_error(error_oss.str());
3540  }
3541  }
3542  // Returning false tells the caller all is clear to proceed with the create model,
3543  // whether that means creating a new one or overwriting an existing model
3544  return false;
3545 }
const std::string & get_model_name() const
Definition: ParserNode.h:1963
bool modelExists(const std::string &model_name) const
Definition: MLModel.h:43
MLModelMap g_ml_models
Definition: MLModel.h:124

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void Parser::CreateModelStmt::execute ( const Catalog_Namespace::SessionInfo & session,
bool  read_only_mode 
)
overridevirtual

Implements Parser::DDLStmt.

Definition at line 3790 of file ParserNode.cpp.

References model_name_, and train_model().

Referenced by heavydb.cursor.Cursor::executemany().

3791  {
3792  if (read_only_mode) {
3793  throw std::runtime_error("CREATE MODEL invalid in read only mode.");
3794  }
3795 
3796  try {
3797  train_model(session);
3798  } catch (std::exception& e) {
3799  std::ostringstream error_oss;
3800  // Error messages from table functions come back like this:
3801  // Error executing table function: MLTableFunctions.hpp:269 linear_reg_fit_impl: No
3802  // rows exist in training input. Training input must at least contain 1 row.
3803 
3804  // We want to take everything after the function name, so we will search for the
3805  // third colon.
3806  // Todo(todd): Look at making this less hacky by setting a mode for the table
3807  // function that will return only the core error string and not the prepending
3808  // metadata
3809 
3810  auto get_error_substring = [](const std::string& message) -> std::string {
3811  size_t colon_position = std::string::npos;
3812  for (int i = 0; i < 3; ++i) {
3813  colon_position = message.find(':', colon_position + 1);
3814  if (colon_position == std::string::npos) {
3815  return message;
3816  }
3817  }
3818 
3819  if (colon_position + 2 >= message.length()) {
3820  return message;
3821  }
3822  return message.substr(colon_position + 2);
3823  };
3824 
3825  const auto error_substr = get_error_substring(e.what());
3826 
3827  error_oss << "Could not create model " << model_name_ << ". " << error_substr;
3828  throw std::runtime_error(error_oss.str());
3829  }
3830 }
void train_model(const Catalog_Namespace::SessionInfo &session)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

const std::string& Parser::CreateModelStmt::get_model_name ( ) const
inline

Definition at line 1963 of file ParserNode.h.

References model_name_.

Referenced by check_model_exists(), and train_model().

1963 { return model_name_; }

+ Here is the caller graph for this function:

const std::string& Parser::CreateModelStmt::get_select_query ( ) const
inline

Definition at line 1964 of file ParserNode.h.

References select_query_.

1964 { return select_query_; }
std::string select_query_
Definition: ParserNode.h:1972
void Parser::CreateModelStmt::parse_model_options ( )
private

Definition at line 3547 of file ParserNode.cpp.

References data_split_eval_fraction_, data_split_train_fraction_, Parser::DoubleLiteral::get_doubleval(), Parser::IntLiteral::get_intval(), Parser::StringLiteral::get_stringval(), model_options_, num_options_, and options_oss_.

Referenced by train_model().

3547  {
3548  bool train_fraction_specified = false;
3549  bool eval_fraction_specified = false;
3550  for (auto& p : model_options_) {
3551  const auto key = boost::to_lower_copy<std::string>(*p->get_name());
3552  if (key == "train_fraction" || key == "data_split_train_fraction") {
3553  if (train_fraction_specified) {
3554  throw std::runtime_error(
3555  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3556  "Expected only one value.");
3557  }
3558  const DoubleLiteral* fp_literal =
3559  dynamic_cast<const DoubleLiteral*>(p->get_value());
3560  if (fp_literal != nullptr) {
3561  data_split_train_fraction_ = fp_literal->get_doubleval();
3562  if (data_split_train_fraction_ <= 0.0 || data_split_train_fraction_ > 1.0) {
3563  throw std::runtime_error(
3564  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3565  "Expected value between 0.0 and 1.0.");
3566  }
3567  } else {
3568  throw std::runtime_error(
3569  "Error parsing DATA_SPLIT_TRAIN_FRACTION value. "
3570  "Expected floating point value betwen 0.0 and 1.0.");
3571  }
3572  train_fraction_specified = true;
3573  continue;
3574  }
3575  if (key == "eval_fraction" || key == "data_split_eval_fraction") {
3576  if (eval_fraction_specified) {
3577  throw std::runtime_error(
3578  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3579  "Expected only one value.");
3580  }
3581  const DoubleLiteral* fp_literal =
3582  dynamic_cast<const DoubleLiteral*>(p->get_value());
3583  if (fp_literal != nullptr) {
3584  data_split_eval_fraction_ = fp_literal->get_doubleval();
3585  if (data_split_eval_fraction_ < 0.0 || data_split_eval_fraction_ >= 1.0) {
3586  throw std::runtime_error(
3587  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3588  "Expected value between 0.0 and 1.0.");
3589  }
3590  } else {
3591  throw std::runtime_error(
3592  "Error parsing DATA_SPLIT_EVAL_FRACTION value. "
3593  "Expected floating point value betwen 0.0 and 1.0.");
3594  }
3595  eval_fraction_specified = true;
3596  continue;
3597  }
3598  if (num_options_) {
3599  options_oss_ << ", ";
3600  }
3601  num_options_++;
3602  options_oss_ << key << " => ";
3603  const StringLiteral* str_literal = dynamic_cast<const StringLiteral*>(p->get_value());
3604  if (str_literal != nullptr) {
3605  options_oss_ << "'"
3606  << boost::to_lower_copy<std::string>(*str_literal->get_stringval())
3607  << "'";
3608  continue;
3609  }
3610  const IntLiteral* int_literal = dynamic_cast<const IntLiteral*>(p->get_value());
3611  if (int_literal != nullptr) {
3612  options_oss_ << int_literal->get_intval();
3613  continue;
3614  }
3615  const DoubleLiteral* fp_literal = dynamic_cast<const DoubleLiteral*>(p->get_value());
3616  if (fp_literal != nullptr) {
3617  options_oss_ << fp_literal->get_doubleval();
3618  continue;
3619  }
3620  throw std::runtime_error("Error parsing value.");
3621  }
3622 
3623  // First handle case where data_split_train_fraction was left to default value
3624  // and data_split_eval_fraction was specified. We shouldn't error here,
3625  // but rather set data_split_train_fraction to 1.0 - data_split_eval_fraction
3626  // Likewise if data_split_eval_fraction was left to default value and we have
3627  // a specified data_split_train_fraction, we should set data_split_eval_fraction
3628  // to 1.0 - data_split_train_fraction
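3629  if (data_split_eval_fraction_ > 0.0 && data_split_train_fraction_ == 1.0) {
3630  data_split_train_fraction_ = 1.0 - data_split_eval_fraction_;
3631  } else if (data_split_eval_fraction_ == 0.0 && data_split_train_fraction_ < 1.0) {
3632  data_split_eval_fraction_ = 1.0 - data_split_train_fraction_;
3633  }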
3634 
3635  // If data_split_train_fraction was specified, and data_split_train_fraction +
3636  // data_split_eval_fraction > 1.0, then we should error
3637  if (data_split_train_fraction_ + data_split_eval_fraction_ > 1.0) {
3638  throw std::runtime_error(
3639  "Error parsing DATA_SPLIT_TRAIN_FRACTION and DATA_SPLIT_EVAL_FRACTION values. "
3640  "Expected sum of values to be less than or equal to 1.0.");
3641  }
3642 }
std::list< std::unique_ptr< NameValueAssign > > model_options_
Definition: ParserNode.h:1975
std::ostringstream options_oss_
Definition: ParserNode.h:1976

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void Parser::CreateModelStmt::train_model ( const Catalog_Namespace::SessionInfo & session)

Definition at line 3731 of file ParserNode.cpp.

References build_model_query(), check_model_exists(), query_state::QueryState::create(), data_split_eval_fraction_, data_split_train_fraction_, shared::encode_base64(), feature_permutations_, get_ml_model_type_str(), get_model_name(), model_feature_vars_, model_predicted_var_, model_type_, num_options_, options_oss_, parse_model_options(), Parser::LocalQueryConnector::query(), select_query_, and Parser::write_model_params_to_json().

Referenced by execute().

3731  {
3732  if (check_model_exists()) {
3733  // Will return true if model exists and if_not_exists_ is true, in this
3734  // case we should return only
3735  return;
3736  }
3737 
3738  parse_model_options();
3739 
3740  auto session_copy = session;
3741  auto session_ptr = std::shared_ptr<Catalog_Namespace::SessionInfo>(
3742  &session_copy, boost::null_deleter());
3743 
3744  // We need to do various manipulations on the raw select query, such
3745  // as adding in any sampling or feature permutation logic. All of this
3746  // work is encapsulated in build_model_query
3747 
3748  const auto modified_select_query = build_model_query(session_ptr);
3749 
3750  // We have to base64 encode the model metadata because depending on the query,
3751  // the training data can have single quotes that trips up the parsing of the combined
3752  // select query with this metadata embedded.
3753 
3754  // This is just a temporary workaround until we store this info in the Catalog
3755  // rather than in the stored model pointer itself (and have to pass the metadata
3756  // down through the table function call)
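3757  const auto model_metadata =
3758  shared::encode_base64(write_model_params_to_json(model_predicted_var_,
3759  model_feature_vars_,
3760  select_query_,
3761  data_split_train_fraction_,
3762  data_split_eval_fraction_,
3763  feature_permutations_));
3764  if (num_options_) {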
3765  // The options string does not have a trailing comma,
3766  // so add it
3767  options_oss_ << ", ";
3768  }
3769  options_oss_ << "model_metadata => '" << model_metadata << "'";
3770 
3771  const std::string options_str = options_oss_.str();
3772 
3773  const std::string model_train_func = get_ml_model_type_str(model_type_) + "_FIT";
3774 
3775  std::ostringstream model_query_oss;
3776  model_query_oss << "SELECT * FROM TABLE(" << model_train_func << "(model_name=>'"
3777  << get_model_name() << "', data=>CURSOR(" << modified_select_query
3778  << ")";
3779  model_query_oss << ", " << options_str;
3780  model_query_oss << "))";
3781 
3782  std::string wrapped_model_query = model_query_oss.str();
3783  auto query_state = query_state::QueryState::create(session_ptr, wrapped_model_query);
3784  // Don't need result back from query, as the query will create the model
3785  LocalQueryConnector local_connector;
3786  local_connector.query(
3787  query_state->createQueryStateProxy(), wrapped_model_query, {}, false);
3788 }
std::string get_ml_model_type_str(const MLModelType model_type)
Definition: MLModelType.h:27
const std::string & get_model_name() const
Definition: ParserNode.h:1963
static std::shared_ptr< QueryState > create(ARGS &&...args)
Definition: QueryState.h:148
std::string write_model_params_to_json(const std::string &predicted, const std::vector< std::string > &features, const std::string &training_query, const double data_split_train_fraction, const double data_split_eval_fraction, const std::vector< int64_t > &feature_permutations)
std::string model_predicted_var_
Definition: ParserNode.h:1980
std::vector< std::string > model_feature_vars_
Definition: ParserNode.h:1981
std::vector< int64_t > feature_permutations_
Definition: ParserNode.h:1982
std::ostringstream options_oss_
Definition: ParserNode.h:1976
std::string build_model_query(const std::shared_ptr< Catalog_Namespace::SessionInfo > session_ptr)
std::string select_query_
Definition: ParserNode.h:1972
static std::string encode_base64(const std::string &val)
Definition: base64.h:45

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

Member Data Documentation

double Parser::CreateModelStmt::data_split_eval_fraction_ {0.0}
private

Definition at line 1979 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

double Parser::CreateModelStmt::data_split_train_fraction_ {1.0}
private

Definition at line 1978 of file ParserNode.h.

Referenced by build_model_query(), parse_model_options(), and train_model().

std::vector<int64_t> Parser::CreateModelStmt::feature_permutations_
private

Definition at line 1982 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

bool Parser::CreateModelStmt::if_not_exists_
private

Definition at line 1974 of file ParserNode.h.

Referenced by check_model_exists(), and CreateModelStmt().

std::vector<std::string> Parser::CreateModelStmt::model_feature_vars_
private

Definition at line 1981 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

std::string Parser::CreateModelStmt::model_name_
private

Definition at line 1971 of file ParserNode.h.

Referenced by CreateModelStmt(), execute(), and get_model_name().

std::list<std::unique_ptr<NameValueAssign> > Parser::CreateModelStmt::model_options_
private

Definition at line 1975 of file ParserNode.h.

Referenced by CreateModelStmt(), and parse_model_options().

std::string Parser::CreateModelStmt::model_predicted_var_
private

Definition at line 1980 of file ParserNode.h.

Referenced by build_model_query(), and train_model().

MLModelType Parser::CreateModelStmt::model_type_
private

Definition at line 1970 of file ParserNode.h.

Referenced by build_model_query(), CreateModelStmt(), and train_model().

size_t Parser::CreateModelStmt::num_options_ {0}
private

Definition at line 1977 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

std::ostringstream Parser::CreateModelStmt::options_oss_
private

Definition at line 1976 of file ParserNode.h.

Referenced by parse_model_options(), and train_model().

bool Parser::CreateModelStmt::replace_
private

Definition at line 1973 of file ParserNode.h.

Referenced by check_model_exists(), and CreateModelStmt().

std::string Parser::CreateModelStmt::select_query_
private

Definition at line 1972 of file ParserNode.h.

Referenced by build_model_query(), CreateModelStmt(), get_select_query(), and train_model().


The documentation for this class was generated from the following files: