OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MLModelMetadata.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "MLModelMetadata.h"
18 
19 #include <rapidjson/document.h>
20 #include <rapidjson/stringbuffer.h>
21 #include <rapidjson/writer.h>
22 
23 void MLModelMetadata::extractModelMetadata(const std::string& model_metadata_json,
24  const int64_t num_logical_features) {
25  rapidjson::Document model_metadata_doc;
26  model_metadata_doc.Parse(model_metadata_json.c_str());
27  if (model_metadata_doc.HasMember("predicted") &&
28  model_metadata_doc["predicted"].IsString()) {
29  predicted_ = model_metadata_doc["predicted"].GetString();
30  }
31  if (model_metadata_doc.HasMember("training_query") &&
32  model_metadata_doc["training_query"].IsString()) {
33  training_query_ = model_metadata_doc["training_query"].GetString();
34  }
35  if (model_metadata_doc.HasMember("features") &&
36  model_metadata_doc["features"].IsArray()) {
37  const rapidjson::Value& features_array = model_metadata_doc["features"];
38  for (const auto& feature : features_array.GetArray()) {
39  features_.emplace_back(feature.GetString());
40  }
41  } else {
42  features_.resize(num_logical_features, "");
43  }
44  if (model_metadata_doc.HasMember("data_split_train_fraction") &&
45  model_metadata_doc["data_split_train_fraction"].IsDouble()) {
46  // Extract the double value
48  model_metadata_doc["data_split_train_fraction"].GetDouble();
49  }
50  if (model_metadata_doc.HasMember("data_split_eval_fraction") &&
51  model_metadata_doc["data_split_eval_fraction"].IsDouble()) {
52  // Extract the double value
54  model_metadata_doc["data_split_eval_fraction"].GetDouble();
55  }
56  if (model_metadata_doc.HasMember("feature_permutations") &&
57  model_metadata_doc["feature_permutations"].IsArray()) {
58  const rapidjson::Value& feature_permutations_array =
59  model_metadata_doc["feature_permutations"];
60  for (const auto& feature_permutation : feature_permutations_array.GetArray()) {
61  feature_permutations_.emplace_back(feature_permutation.GetInt64());
62  }
63  }
64 }
double data_split_eval_fraction_
std::string training_query_
double data_split_train_fraction_
void extractModelMetadata(const std::string &model_metadata_json, const int64_t num_logical_features)
std::vector< std::string > features_
std::vector< int64_t > feature_permutations_
std::string predicted_