OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchWrapper.h
Go to the documentation of this file.
1 /*
2  * Copyright 2023 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #ifndef __CUDACC__
20 
23 
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
26 
/// One detection as returned (in a std::vector) by
/// detect_objects_in_tiled_raster().
///
/// Every scalar member is value-initialised (zero), so a default-initialised
/// `Detection d;` never carries indeterminate values. The initialisers match
/// what `Detection{}` already produced. The struct remains an aggregate, so
/// brace initialisation in declaration order still works:
/// `Detection{idx, label, cx, cy, w, h, conf}`.
struct Detection {
  int32_t class_idx{};      // class index of the detection
  std::string class_label;  // label for class_idx (presumably looked up in
                            // ModelInfo::class_labels — TODO confirm)
  double centroid_x{};      // centroid x coordinate (units/CRS not established
                            // in this header — TODO confirm)
  double centroid_y{};      // centroid y coordinate
  double width{};           // detection width (same units as centroid_x)
  double height{};          // detection height (same units as centroid_y)
  float confidence{};       // detection score; NOTE(review): presumably the
                            // value filtered by min_confidence_threshold
};
36 
/// A detection box described by two corner points, plus a score and a class.
/// NOTE(review): the `tl_*` / `br_*` prefixes presumably mean top-left and
/// bottom-right; the coordinate system is not established here — confirm.
///
/// All members are value-initialised (zero), so a default-initialised
/// `BoxDetection b;` is never indeterminate. This matches `BoxDetection{}`.
/// The struct remains a trivially copyable aggregate:
/// `BoxDetection{tlx, tly, brx, bry, score, cls}` still works.
struct BoxDetection {
  double tl_x{};   // first corner, x
  double tl_y{};   // first corner, y
  double br_x{};   // opposite corner, x
  double br_y{};   // opposite corner, y
  float score{};   // detection score
  int class_idx{};  // class index; NOTE(review): `int` here vs `int32_t` in
                    // Detection — kept as-is to avoid changing the type
};
45 
/// Model metadata, as returned by get_model_info_from_file().
///
/// Defaults: is_valid is false, every integer field is the sentinel -1, and
/// class_labels is empty. NOTE(review): presumably is_valid == false means the
/// metadata could not be obtained — confirm against the implementation.
struct ModelInfo {
  bool is_valid = false;
  int64_t batch_size = -1;
  int64_t raster_channels = -1;     // presumably expected input channel count
  int64_t raster_tile_width = -1;   // presumably expected tile width
  int64_t raster_tile_height = -1;  // presumably expected tile height
  int64_t stride = -1;              // meaning not established here — TODO confirm
  std::vector<std::string> class_labels{};
};
55 
56 ModelInfo get_model_info_from_file(const std::string& filename);
57 
58 std::vector<Detection> detect_objects_in_tiled_raster(
59  const std::string& model_path,
60  const ModelInfo& model_info,
61  const bool use_gpu,
62  const int64_t device_num,
63  std::vector<float>& raster_data,
64  const RasterFormat_Namespace::RasterInfo& raster_info,
65  const float min_confidence_threshold,
66  const float iou_threshold,
67  std::shared_ptr<CpuTimer> timer);
68 
/// Holder for Torch warm-up state; it declares only static members.
/// NOTE(review): both members are defined outside this header (the static data
/// member needs an out-of-line definition); behaviour below is inferred from
/// names — confirm against the implementation.
class TorchWarmer {
 public:
  /// Presumably set once warm-up has been performed.
  /// NOTE(review): a plain `bool` with no synchronisation — if read/written from
  /// several threads, confirm callers serialise access.
  static bool is_torch_warmed;

  /// Presumably loads the model at `model_path` on CPU or GPU `device_num` to
  /// prime the runtime; returns a success flag (semantics unconfirmed).
  static auto warmup_torch(const std::string& model_path,
                           bool use_gpu,
                           int64_t device_num) -> bool;
};
76 
77 #endif // __CUDACC__
int64_t stride
Definition: TorchWrapper.h:52
bool is_valid
Definition: TorchWrapper.h:47
ModelInfo get_model_info_from_file(const std::string &filename)
double width
Definition: TorchWrapper.h:32
std::vector< Detection > detect_objects_in_tiled_raster(const std::string &model_path, const ModelInfo &model_info, const bool use_gpu, const int64_t device_num, std::vector< float > &raster_data, const RasterFormat_Namespace::RasterInfo &raster_info, const float min_confidence_threshold, const float iou_threshold, std::shared_ptr< CpuTimer > timer)
std::string class_label
Definition: TorchWrapper.h:29
int32_t class_idx
Definition: TorchWrapper.h:28
std::vector< std::string > class_labels
Definition: TorchWrapper.h:53
int64_t raster_tile_width
Definition: TorchWrapper.h:50
double centroid_y
Definition: TorchWrapper.h:31
static bool warmup_torch(const std::string &model_path, const bool use_gpu, const int64_t device_num)
double centroid_x
Definition: TorchWrapper.h:30
static bool is_torch_warmed
Definition: TorchWrapper.h:74
float confidence
Definition: TorchWrapper.h:34
double height
Definition: TorchWrapper.h:33
int64_t raster_channels
Definition: TorchWrapper.h:49
int64_t batch_size
Definition: TorchWrapper.h:48
int64_t raster_tile_height
Definition: TorchWrapper.h:51