OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MLModelType.h
Go to the documentation of this file.
1 /*
2  * Copyright 2023 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include "Logger/Logger.h"
20 #include "Shared/StringTransform.h"
21 #include "Shared/misc.h"
22 
23 #include <string>
24 
26 
27 inline std::string get_ml_model_type_str(const MLModelType model_type) {
28  switch (model_type) {
30  return "LINEAR_REG";
31  }
33  return "DECISION_TREE_REG";
34  }
35  case MLModelType::GBT_REG: {
36  return "GBT_REG";
37  }
39  return "RANDOM_FOREST_REG";
40  }
41  case MLModelType::PCA: {
42  return "PCA";
43  }
44  default: {
45  CHECK(false) << "Unknown model type.";
46  // Satisfy compiler
47  return "LINEAR_REG";
48  }
49  }
50 }
51 
52 inline MLModelType get_ml_model_type_from_str(const std::string& model_type_str) {
53  const auto upper_model_type_str = to_upper(model_type_str);
54  if (upper_model_type_str == "LINEAR_REG") {
56  } else if (upper_model_type_str == "DECISION_TREE_REG") {
58  } else if (upper_model_type_str == "GBT_REG") {
59  return MLModelType::GBT_REG;
60  } else if (upper_model_type_str == "RANDOM_FOREST_REG") {
62  } else if (upper_model_type_str == "PCA") {
63  return MLModelType::PCA;
64  } else {
65  throw std::invalid_argument("Unknown model type: " + upper_model_type_str);
66  }
67 }
68 
69 inline bool is_regression_model(const MLModelType model_type) {
73  MLModelType::RANDOM_FOREST_REG>(model_type);
74 }
std::string get_ml_model_type_str(const MLModelType model_type)
Definition: MLModelType.h:27
bool is_any(T &&value)
Definition: misc.h:258
MLModelType
Definition: MLModelType.h:25
std::string to_upper(const std::string &str)
#define CHECK(condition)
Definition: Logger.h:291
bool is_regression_model(const MLModelType model_type)
Definition: MLModelType.h:69
MLModelType get_ml_model_type_from_str(const std::string &model_type_str)
Definition: MLModelType.h:52