OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
AbstractMLModel.h
Go to the documentation of this file.
1 /*
2  * Copyright 2023 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include "MLModelMetadata.h"
#include "MLModelType.h"

#include <string>
#include <utility>
#include <vector>
#include "Shared/base64.h"
24 
25 namespace {
26 std::string default_metadata(const std::string& metadata) {
27  if (metadata == "" || metadata == "DEFAULT") {
28  return "{}";
29  }
30  return shared::decode_base64(metadata);
31 }
32 } // namespace
33 
35  public:
36  AbstractMLModel(const std::string& model_metadata)
37  : model_metadata_(default_metadata(model_metadata)) {}
38 
39  AbstractMLModel(const std::string& model_metadata,
40  const std::vector<std::vector<std::string>>& cat_feature_keys)
41  : model_metadata_(default_metadata(model_metadata))
42  , cat_feature_keys_(cat_feature_keys) {}
  /// Concrete model kind; every subclass must report its own MLModelType.
43  virtual MLModelType getModelType() const = 0;
  /// String form of the model type, supplied by each subclass.
44  virtual std::string getModelTypeString() const = 0;
  /// Feature count reported by the concrete model.
  /// NOTE(review): confirm whether this counts raw categorical columns or their
  /// one-hot expansion; it is defined only in subclasses.
45  virtual int64_t getNumFeatures() const = 0;
  /// Virtual so that destroying a derived model through an AbstractMLModel
  /// pointer or reference runs the derived destructor.
46  virtual ~AbstractMLModel() = default;
47  const std::string& getModelMetadataStr() const { return model_metadata_; }
49  return MLModelMetadata("",
50  getModelType(),
57  }
58  const std::vector<std::vector<std::string>>& getCatFeatureKeys() const {
59  return cat_feature_keys_;
60  }
61  const int64_t getNumCatFeatures() const { return cat_feature_keys_.size(); }
62 
63  const int64_t getNumOneHotFeatures() const {
64  int64_t num_one_hot_features{0};
65  for (const auto& cat_feature_key : cat_feature_keys_) {
66  num_one_hot_features += static_cast<int64_t>(cat_feature_key.size());
67  }
68  return num_one_hot_features;
69  }
70 
71  const int64_t getNumLogicalFeatures() const {
73  }
74 
75  protected:
  // Metadata string as normalized by the constructors via default_metadata();
  // exposed read-only through getModelMetadataStr().
76  std::string model_metadata_;
  // One inner vector per categorical feature, holding that feature's keys.
  // Left empty by the single-argument constructor.
77  std::vector<std::vector<std::string>> cat_feature_keys_;
78 };
virtual MLModelType getModelType() const =0
std::string default_metadata(const std::string &metadata)
MLModelMetadata getModelMetadata() const
std::vector< std::vector< std::string > > cat_feature_keys_
MLModelType
Definition: MLModelType.h:25
const int64_t getNumCatFeatures() const
virtual ~AbstractMLModel()=default
std::string model_metadata_
virtual int64_t getNumFeatures() const =0
const std::string & getModelMetadataStr() const
const int64_t getNumLogicalFeatures() const
const int64_t getNumOneHotFeatures() const
const std::vector< std::vector< std::string > > & getCatFeatureKeys() const
std::string decode_base64(const std::string &val, bool trim_nulls)
Definition: base64.h:27
AbstractMLModel(const std::string &model_metadata, const std::vector< std::vector< std::string >> &cat_feature_keys)
AbstractMLModel(const std::string &model_metadata)
virtual std::string getModelTypeString() const =0