OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchTableFunctions.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 HEAVY.AI, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef __CUDACC__
18 
19 #include "TorchTableFunctions.h"
23 
24 #include <cstdio>
25 
// CPU table function (this file is compiled only when __CUDACC__ is not
// defined) that runs a Torch object-detection model over a raster.
// The raster is supplied as per-pixel x/y coordinate columns plus a list of
// color-channel columns. The function emits one output row per detected object.
//
// NOTE(review): this rendering omits source line 28, the declarator
// `tf_torch_raster_obj_detect__cpu_template(TableFunctionManager& mgr,` (the
// full signature appears in the definition list at the end of this page);
// the `mgr` used below is that first parameter.
//
// Parameters:
//   input_x, input_y          - pixel coordinate columns, forwarded to
//                               RasterFormat_Namespace::format_raster_data.
//   input_channels            - one column per color channel.
//   x/y_input_units_per_pixel - coordinate scale, forwarded to
//                               format_raster_data (presumably input units per
//                               raster pixel along each axis — confirm).
//   max_color_value           - forwarded to format_raster_data, which produces
//                               float data (presumably used to normalize channel
//                               values — confirm).
//   tile_boundary_halo_pixels - forwarded to format_raster_data (presumably the
//                               overlap added around each tile — confirm).
//   model_path                - Torch model file passed to
//                               detect_objects_in_tiled_raster.
//   model_metadata_path       - file read by get_model_info_from_file; when
//                               empty, model_path is used instead.
//   min_confidence_threshold,
//   iou_threshold             - forwarded to detect_objects_in_tiled_raster.
//   use_gpu, device_num       - forwarded to detect_objects_in_tiled_raster.
//   detected_*                - output columns, one row per detection.
//                               detected_class_label holds a transient
//                               dictionary id, or null when the detection's
//                               class idx has no label.
// Returns: the number of output rows on success. On failure it returns the
// result of mgr.ERROR_MESSAGE(...). Failure cases are: model info invalid, no
// class labels, or a std::runtime_error thrown inside the try block.
26 template <typename PixelType, typename ColorType>
27 TEMPLATE_NOINLINE int32_t
29  const Column<PixelType>& input_x,
30  const Column<PixelType>& input_y,
31  const ColumnList<ColorType>& input_channels,
32  const PixelType x_input_units_per_pixel,
33  const PixelType y_input_units_per_pixel,
34  const float max_color_value,
35  const int64_t tile_boundary_halo_pixels,
36  const TextEncodingNone& model_path,
37  const TextEncodingNone& model_metadata_path,
38  const float min_confidence_threshold,
39  const float iou_threshold,
40  const bool use_gpu,
41  const int64_t device_num,
42  Column<TextEncodingDict>& detected_class_label,
43  Column<int32_t>& detected_class_id,
44  Column<double>& detected_centroid_x,
45  Column<double>& detected_centroid_y,
46  Column<double>& detected_width,
47  Column<double>& detected_height,
48  Column<float>& detected_confidence) {
// Shared timer. Nested timers derived from it are handed to the raster
// formatting and detection stages so their phases are recorded under one name.
49  std::shared_ptr<CpuTimer> timer =
50  std::make_shared<CpuTimer>("tf_torch_raster_obj_detect");
51 
52  timer->start_event_timer("get_model_info_from_file");
53  const std::string model_path_str(model_path.getString());
// An empty metadata path means "read model metadata from the model file itself".
54  std::string model_metadata_path_str(model_metadata_path.getString());
55  if (model_metadata_path_str.empty()) {
56  model_metadata_path_str = model_path_str;
57  }
// NOTE(review): this call and format_raster_data below run outside the try
// block. If either can throw, the exception is not converted into
// mgr.ERROR_MESSAGE — confirm whether that is intended.
58  const auto model_info = get_model_info_from_file(model_metadata_path_str);
59  if (!model_info.is_valid) {
60  return mgr.ERROR_MESSAGE("Could not get model info from file.");
61  }
62  if (model_info.class_labels.empty()) {
63  return mgr.ERROR_MESSAGE("Could not get class labels from file.");
64  }
// Map each class label to a transient string-dictionary id up front.
// The vector index is the model's class idx, so results can be written
// with a plain index lookup.
65  const auto class_idx_to_label_vec =
66  detected_class_label.string_dict_proxy_->getOrAddTransientBulk(
67  model_info.class_labels);
68  const int64_t num_class_labels = static_cast<int64_t>(class_idx_to_label_vec.size());
69 
// Forwarded to format_raster_data. The name suggests the number of tiles is
// rounded to a multiple of 8 for batched inference — confirm in
// format_raster_data.
70  constexpr int64_t target_batch_size_multiple{8};
71 
// Tile dimensions come from the model metadata. raster_data.first is the
// float pixel buffer and raster_data.second the RasterInfo; both are
// consumed by detect_objects_in_tiled_raster.
72  auto raster_data =
73  RasterFormat_Namespace::format_raster_data<PixelType, ColorType, float>(
74  input_x,
75  input_y,
76  input_channels,
77  x_input_units_per_pixel,
78  y_input_units_per_pixel,
79  max_color_value,
80  model_info.raster_tile_width,
81  model_info.raster_tile_height,
82  tile_boundary_halo_pixels,
83  target_batch_size_multiple,
84  timer->start_nested_event_timer("format_raster_data"));
85 
86  try {
87  const auto processed_detections = detect_objects_in_tiled_raster(
88  model_path_str,
89  model_info,
90  use_gpu,
91  device_num,
92  raster_data.first,
93  raster_data.second,
94  min_confidence_threshold,
95  iou_threshold,
96  timer->start_nested_event_timer("detect_objects_in_tiled_raster"));
97  timer->start_event_timer("Write results");
// One output row per detection. The output row count must be set before
// the columns are written.
98  const int64_t num_detections = processed_detections.size();
99  mgr.set_output_row_size(num_detections);
100  // The class labels taken from the model file will be in same order as class idxs by
101  // definition
102  for (int64_t detection_idx = 0; detection_idx < num_detections; ++detection_idx) {
103  const auto class_idx = processed_detections[detection_idx].class_idx;
// The raw class id is always written. The label is null when the idx falls
// outside the label table, so an unexpected idx cannot index past the
// end of class_idx_to_label_vec.
104  detected_class_id[detection_idx] = class_idx;
105  if (class_idx < 0 || class_idx >= num_class_labels) {
106  detected_class_label.setNull(detection_idx);
107  } else {
108  detected_class_label[detection_idx] = class_idx_to_label_vec[class_idx];
109  }
110  detected_centroid_x[detection_idx] = processed_detections[detection_idx].centroid_x;
111  detected_centroid_y[detection_idx] = processed_detections[detection_idx].centroid_y;
112  detected_width[detection_idx] = processed_detections[detection_idx].width;
113  detected_height[detection_idx] = processed_detections[detection_idx].height;
114  detected_confidence[detection_idx] = processed_detections[detection_idx].confidence;
115  }
// NOTE(review): num_detections is int64_t but the function returns int32_t;
// the conversion is implicit.
116  return num_detections;
// Only std::runtime_error is translated into an error message. Other
// exception types propagate to the caller.
117  } catch (const std::runtime_error& e) {
118  return mgr.ERROR_MESSAGE(e.what());
119  }
// Unreachable: both the try body and the catch handler return.
120  return 0;
121 }
122 
// Explicit instantiations of the raster object-detection table function template
// for every supported (PixelType, ColorType) combination:
//   (float, int16_t), (float, int32_t), (double, int16_t), (double, int32_t).
// Because the template is defined in this .cpp file, only these combinations are
// available to other translation units.
// NOTE(review): this rendering omits each instantiation's declarator line
// (source lines 124, 147, 170, 193), i.e. the
// `tf_torch_raster_obj_detect__cpu_template(TableFunctionManager& mgr,` line.
// PixelType = float, ColorType = int16_t
123 template TEMPLATE_NOINLINE int32_t
125  const Column<float>& input_x,
126  const Column<float>& input_y,
127  const ColumnList<int16_t>& input_channels,
128  const float x_input_units_per_pixel,
129  const float y_input_units_per_pixel,
130  const float max_color_value,
131  const int64_t tile_boundary_halo_pixels,
132  const TextEncodingNone& model_path,
133  const TextEncodingNone& model_metadata_path,
134  const float min_confidence_threshold,
135  const float iou_threshold,
136  const bool use_gpu,
137  const int64_t device_num,
138  Column<TextEncodingDict>& detected_class_label,
139  Column<int32_t>& detected_class_id,
140  Column<double>& detected_centroid_x,
141  Column<double>& detected_centroid_y,
142  Column<double>& detected_width,
143  Column<double>& detected_height,
144  Column<float>& detected_confidence);
145 
// PixelType = float, ColorType = int32_t
146 template TEMPLATE_NOINLINE int32_t
148  const Column<float>& input_x,
149  const Column<float>& input_y,
150  const ColumnList<int32_t>& input_channels,
151  const float x_input_units_per_pixel,
152  const float y_input_units_per_pixel,
153  const float max_color_value,
154  const int64_t tile_boundary_halo_pixels,
155  const TextEncodingNone& model_path,
156  const TextEncodingNone& model_metadata_path,
157  const float min_confidence_threshold,
158  const float iou_threshold,
159  const bool use_gpu,
160  const int64_t device_num,
161  Column<TextEncodingDict>& detected_class_label,
162  Column<int32_t>& detected_class_id,
163  Column<double>& detected_centroid_x,
164  Column<double>& detected_centroid_y,
165  Column<double>& detected_width,
166  Column<double>& detected_height,
167  Column<float>& detected_confidence);
168 
// PixelType = double, ColorType = int16_t
169 template TEMPLATE_NOINLINE int32_t
171  const Column<double>& input_x,
172  const Column<double>& input_y,
173  const ColumnList<int16_t>& input_channels,
174  const double x_input_units_per_pixel,
175  const double y_input_units_per_pixel,
176  const float max_color_value,
177  const int64_t tile_boundary_halo_pixels,
178  const TextEncodingNone& model_path,
179  const TextEncodingNone& model_metadata_path,
180  const float min_confidence_threshold,
181  const float iou_threshold,
182  const bool use_gpu,
183  const int64_t device_num,
184  Column<TextEncodingDict>& detected_class_label,
185  Column<int32_t>& detected_class_id,
186  Column<double>& detected_centroid_x,
187  Column<double>& detected_centroid_y,
188  Column<double>& detected_width,
189  Column<double>& detected_height,
190  Column<float>& detected_confidence);
191 
// PixelType = double, ColorType = int32_t
192 template TEMPLATE_NOINLINE int32_t
194  const Column<double>& input_x,
195  const Column<double>& input_y,
196  const ColumnList<int32_t>& input_channels,
197  const double x_input_units_per_pixel,
198  const double y_input_units_per_pixel,
199  const float max_color_value,
200  const int64_t tile_boundary_halo_pixels,
201  const TextEncodingNone& model_path,
202  const TextEncodingNone& model_metadata_path,
203  const float min_confidence_threshold,
204  const float iou_threshold,
205  const bool use_gpu,
206  const int64_t device_num,
207  Column<TextEncodingDict>& detected_class_label,
208  Column<int32_t>& detected_class_id,
209  Column<double>& detected_centroid_x,
210  Column<double>& detected_centroid_y,
211  Column<double>& detected_width,
212  Column<double>& detected_height,
213  Column<float>& detected_confidence);
214 
215 #endif
void set_output_row_size(int64_t num_rows)
Definition: heavydbTypes.h:373
ModelInfo get_model_info_from_file(const std::string &filename)
std::string getString() const
Definition: heavydbTypes.h:641
std::vector< Detection > detect_objects_in_tiled_raster(const std::string &model_path, const ModelInfo &model_info, const bool use_gpu, const int64_t device_num, std::vector< float > &raster_data, const RasterFormat_Namespace::RasterInfo &raster_info, const float min_confidence_threshold, const float iou_threshold, std::shared_ptr< CpuTimer > timer)
DEVICE void setNull(int64_t index)
StringDictionaryProxy * string_dict_proxy_
TEMPLATE_NOINLINE int32_t tf_torch_raster_obj_detect__cpu_template(TableFunctionManager &mgr, const Column< PixelType > &input_x, const Column< PixelType > &input_y, const ColumnList< ColorType > &input_channels, const PixelType x_input_units_per_pixel, const PixelType y_input_units_per_pixel, const float max_color_value, const int64_t tile_boundary_halo_pixels, const TextEncodingNone &model_path, const TextEncodingNone &model_metadata_path, const float min_confidence_threshold, const float iou_threshold, const bool use_gpu, const int64_t device_num, Column< TextEncodingDict > &detected_class_label, Column< int32_t > &detected_class_id, Column< double > &detected_centroid_x, Column< double > &detected_centroid_y, Column< double > &detected_width, Column< double > &detected_height, Column< float > &detected_confidence)
std::vector< int32_t > getOrAddTransientBulk(const std::vector< std::string > &strings)
#define TEMPLATE_NOINLINE
Definition: heavydbTypes.h:60