OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ShowModelDetailsCommand Class Reference

#include <DdlCommandExecutor.h>

+ Inheritance diagram for ShowModelDetailsCommand:
+ Collaboration diagram for ShowModelDetailsCommand:

Public Member Functions

 ShowModelDetailsCommand (const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)
 
ExecutionResult execute (bool read_only_mode) override
 
- Public Member Functions inherited from DdlCommand
 DdlCommand (const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)
 

Private Member Functions

std::vector< std::string > getFilteredModelNames ()
 

Additional Inherited Members

- Protected Attributes inherited from DdlCommand
const DdlCommandData & ddl_data_
 
std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr_
 

Detailed Description

Definition at line 280 of file DdlCommandExecutor.h.

Constructor & Destructor Documentation

ShowModelDetailsCommand::ShowModelDetailsCommand ( const DdlCommandData & ddl_data,
std::shared_ptr< Catalog_Namespace::SessionInfo const >  session_ptr 
)

Definition at line 2094 of file DdlCommandExecutor.cpp.

References g_enable_ml_functions, and g_restrict_ml_model_metadata_to_superusers.

2097  : DdlCommand(ddl_data, session_ptr) {
2098  if (!g_enable_ml_functions) {
2099  throw std::runtime_error("Cannot show model details. ML functions are disabled.");
2100  }
2101  if (g_restrict_ml_model_metadata_to_superusers) {
2102  // Check if user is super user
2103  const auto& current_user = session_ptr->get_currentUser();
2104  if (!current_user.isSuper) {
2105  throw std::runtime_error(
2106  "Cannot show model details. Showing model information to non-superusers is "
2107  "disabled.");
2108  }
2109  }
2110 }
bool g_restrict_ml_model_metadata_to_superusers
Definition: Execute.cpp:119
bool g_enable_ml_functions
Definition: Execute.cpp:118
DdlCommand(const DdlCommandData &ddl_data, std::shared_ptr< Catalog_Namespace::SessionInfo const > session_ptr)

Member Function Documentation

ExecutionResult ShowModelDetailsCommand::execute ( bool  read_only_mode)
override virtual

Executes the DDL command corresponding to the provided JSON payload.

Parameters
_return  result of DDL command execution (if applicable)

Implements DdlCommand.

Definition at line 2112 of file DdlCommandExecutor.cpp.

References ResultSetLogicalValuesBuilder::create(), g_ml_models, anonymous_namespace{DdlCommandExecutor.cpp}::genLiteralBigInt(), anonymous_namespace{DdlCommandExecutor.cpp}::genLiteralDouble(), genLiteralStr(), legacylockmgr::getExecuteReadLock(), getFilteredModelNames(), MLModelMap::getModelMetadata(), kBIGINT, kDOUBLE, and kTEXT.

Referenced by heavydb.cursor.Cursor::executemany().

2112  {
2113  auto execute_read_lock = legacylockmgr::getExecuteReadLock();
2114 
2115  std::vector<TargetMetaInfo> label_infos;
2116  label_infos.emplace_back("model_name", SQLTypeInfo(kTEXT, true));
2117  label_infos.emplace_back("model_type", SQLTypeInfo(kTEXT, true));
2118  label_infos.emplace_back("predicted", SQLTypeInfo(kTEXT, true));
2119  label_infos.emplace_back("features", SQLTypeInfo(kTEXT, true));
2120  label_infos.emplace_back("training_query", SQLTypeInfo(kTEXT, true));
2121  label_infos.emplace_back("num_logical_features", SQLTypeInfo(kBIGINT, true));
2122  label_infos.emplace_back("num_physical_features", SQLTypeInfo(kBIGINT, true));
2123  label_infos.emplace_back("num_categorical_features", SQLTypeInfo(kBIGINT, true));
2124  label_infos.emplace_back("num_numeric_features", SQLTypeInfo(kBIGINT, true));
2125  label_infos.emplace_back("train_fraction", SQLTypeInfo(kDOUBLE, true));
2126  label_infos.emplace_back("eval_fraction", SQLTypeInfo(kDOUBLE, true));
2127 
2128  // Get all model names
2129  const auto model_names = getFilteredModelNames();
2130 
2131  // logical_values -> table data
2132  std::vector<RelLogicalValues::RowValues> logical_values;
2133  for (auto& model_name : model_names) {
2134  logical_values.emplace_back(RelLogicalValues::RowValues{});
2135  logical_values.back().emplace_back(genLiteralStr(model_name));
2136  const auto model_metadata = g_ml_models.getModelMetadata(model_name);
2137  logical_values.back().emplace_back(genLiteralStr(model_metadata.getModelTypeStr()));
2138  logical_values.back().emplace_back(genLiteralStr(model_metadata.getPredicted()));
2139  const auto& features = model_metadata.getFeatures();
2140  std::ostringstream features_oss;
2141  bool is_first_feature = true;
2142  for (auto& feature : features) {
2143  if (!is_first_feature) {
2144  features_oss << ", ";
2145  } else {
2146  is_first_feature = false;
2147  }
2148  features_oss << feature;
2149  }
2150  auto features_str = features_oss.str();
2151  logical_values.back().emplace_back(genLiteralStr(features_str));
2152  logical_values.back().emplace_back(genLiteralStr(model_metadata.getTrainingQuery()));
2153  logical_values.back().emplace_back(
2154  genLiteralBigInt(model_metadata.getNumLogicalFeatures()));
2155  logical_values.back().emplace_back(genLiteralBigInt(model_metadata.getNumFeatures()));
2156  logical_values.back().emplace_back(
2157  genLiteralBigInt(model_metadata.getNumCategoricalFeatures()));
2158  logical_values.back().emplace_back(
2159  genLiteralBigInt(model_metadata.getNumLogicalFeatures() -
2160  model_metadata.getNumCategoricalFeatures()));
2161  logical_values.back().emplace_back(
2162  genLiteralDouble(model_metadata.getDataSplitTrainFraction()));
2163  logical_values.back().emplace_back(
2164  genLiteralDouble(model_metadata.getDataSplitEvalFraction()));
2165  }
2166 
2167  // Create ResultSet
2168  std::shared_ptr<ResultSet> rSet = std::shared_ptr<ResultSet>(
2169  ResultSetLogicalValuesBuilder::create(label_infos, logical_values));
2170 
2171  return ExecutionResult(rSet, label_infos);
2172 }
std::vector< std::string > getFilteredModelNames()
auto getExecuteReadLock()
std::unique_ptr< RexLiteral > genLiteralDouble(double val)
std::vector< MLModelMetadata > getModelMetadata() const
Definition: MLModel.h:83
std::unique_ptr< RexLiteral > genLiteralBigInt(int64_t val)
static ResultSet * create(std::vector< TargetMetaInfo > &label_infos, std::vector< RelLogicalValues::RowValues > &logical_values)
MLModelMap g_ml_models
Definition: MLModel.h:124
Definition: sqltypes.h:79
static std::unique_ptr< RexLiteral > genLiteralStr(std::string val)
Definition: DBHandler.cpp:7752
std::vector< std::unique_ptr< const RexScalar >> RowValues
Definition: RelAlgDag.h:2656

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

std::vector< std::string > ShowModelDetailsCommand::getFilteredModelNames ( )
private

Definition at line 2174 of file DdlCommandExecutor.cpp.

References DdlCommand::ddl_data_, anonymous_namespace{DdlCommandExecutor.cpp}::extractPayload(), g_ml_models, MLModelMap::getModelNames(), and to_upper().

Referenced by execute().

2174  {
2175  auto& ddl_payload = extractPayload(ddl_data_);
2176  auto all_model_names = g_ml_models.getModelNames();
2177  if (ddl_payload.HasMember("modelNames")) {
2178  std::vector<std::string> filtered_model_names;
2179  std::set<std::string> all_model_names_set(all_model_names.begin(),
2180  all_model_names.end());
2181  for (const auto& model_name_json : ddl_payload["modelNames"].GetArray()) {
2182  std::string model_name = model_name_json.GetString();
2183  const auto model_name_upper = to_upper(model_name);
2184  if (all_model_names_set.find(to_upper(model_name_upper)) ==
2185  all_model_names_set.end()) {
2186  throw std::runtime_error{"Unable to show model details for model: " +
2187  model_name_upper + ". Model does not exist."};
2188  }
2189  filtered_model_names.emplace_back(model_name_upper);
2190  }
2191  return filtered_model_names;
2192  } else {
2193  return all_model_names;
2194  }
2195 }
const DdlCommandData & ddl_data_
const rapidjson::Value & extractPayload(const DdlCommandData &ddl_data)
std::string to_upper(const std::string &str)
MLModelMap g_ml_models
Definition: MLModel.h:124
std::vector< std::string > getModelNames() const
Definition: MLModel.h:74

+ Here is the call graph for this function:

+ Here is the caller graph for this function:


The documentation for this class was generated from the following files: