OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchTableFunctions.h File Reference
+ Include dependency graph for TorchTableFunctions.h:
+ This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

template<typename PixelType , typename ColorType >
TEMPLATE_NOINLINE int32_t tf_torch_raster_obj_detect__cpu_template (TableFunctionManager &mgr, const Column< PixelType > &input_x, const Column< PixelType > &input_y, const ColumnList< ColorType > &input_channels, const PixelType x_input_units_per_pixel, const PixelType y_input_units_per_pixel, const float max_color_value, const int64_t tile_boundary_halo_pixels, const TextEncodingNone &model_path, const TextEncodingNone &model_metadata_path, const float min_confidence_threshold, const float iou_threshold, const bool use_gpu, const int64_t device_num, Column< TextEncodingDict > &detected_class_label, Column< int32_t > &detected_class_id, Column< double > &detected_centroid_x, Column< double > &detected_centroid_y, Column< double > &detected_width, Column< double > &detected_height, Column< float > &detected_confidence)
 

Function Documentation

template<typename PixelType , typename ColorType >
TEMPLATE_NOINLINE int32_t tf_torch_raster_obj_detect__cpu_template ( TableFunctionManager &  mgr,
const Column< PixelType > &  input_x,
const Column< PixelType > &  input_y,
const ColumnList< ColorType > &  input_channels,
const PixelType  x_input_units_per_pixel,
const PixelType  y_input_units_per_pixel,
const float  max_color_value,
const int64_t  tile_boundary_halo_pixels,
const TextEncodingNone &  model_path,
const TextEncodingNone &  model_metadata_path,
const float  min_confidence_threshold,
const float  iou_threshold,
const bool  use_gpu,
const int64_t  device_num,
Column< TextEncodingDict > &  detected_class_label,
Column< int32_t > &  detected_class_id,
Column< double > &  detected_centroid_x,
Column< double > &  detected_centroid_y,
Column< double > &  detected_width,
Column< double > &  detected_height,
Column< float > &  detected_confidence 
)

Definition at line 28 of file TorchTableFunctions.cpp.

References class_idx, detect_objects_in_tiled_raster(), get_model_info_from_file(), StringDictionaryProxy::getOrAddTransientBulk(), TextEncodingNone::getString(), TableFunctionManager::set_output_row_size(), Column< TextEncodingDict >::setNull(), and Column< TextEncodingDict >::string_dict_proxy_.

48  {
49  std::shared_ptr<CpuTimer> timer =
50  std::make_shared<CpuTimer>("tf_torch_raster_obj_detect");
51 
52  timer->start_event_timer("get_model_info_from_file");
53  const std::string model_path_str(model_path.getString());
54  std::string model_metadata_path_str(model_metadata_path.getString());
55  if (model_metadata_path_str.empty()) {
56  model_metadata_path_str = model_path_str;
57  }
58  const auto model_info = get_model_info_from_file(model_metadata_path_str);
59  if (!model_info.is_valid) {
60  return mgr.ERROR_MESSAGE("Could not get model info from file.");
61  }
62  if (model_info.class_labels.empty()) {
63  return mgr.ERROR_MESSAGE("Could not get class labels from file.");
64  }
65  const auto class_idx_to_label_vec =
66  detected_class_label.string_dict_proxy_->getOrAddTransientBulk(
67  model_info.class_labels);
68  const int64_t num_class_labels = static_cast<int64_t>(class_idx_to_label_vec.size());
69 
70  constexpr int64_t target_batch_size_multiple{8};
71 
72  auto raster_data =
73  RasterFormat_Namespace::format_raster_data<PixelType, ColorType, float>(
74  input_x,
75  input_y,
76  input_channels,
77  x_input_units_per_pixel,
78  y_input_units_per_pixel,
79  max_color_value,
80  model_info.raster_tile_width,
81  model_info.raster_tile_height,
82  tile_boundary_halo_pixels,
83  target_batch_size_multiple,
84  timer->start_nested_event_timer("format_raster_data"));
85 
86  try {
87  const auto processed_detections = detect_objects_in_tiled_raster(
88  model_path_str,
89  model_info,
90  use_gpu,
91  device_num,
92  raster_data.first,
93  raster_data.second,
94  min_confidence_threshold,
95  iou_threshold,
96  timer->start_nested_event_timer("detect_objects_in_tiled_raster"));
97  timer->start_event_timer("Write results");
98  const int64_t num_detections = processed_detections.size();
99  mgr.set_output_row_size(num_detections);
100  // The class labels taken from the model file will be in same order as class idxs by
101  // definition
102  for (int64_t detection_idx = 0; detection_idx < num_detections; ++detection_idx) {
103  const auto class_idx = processed_detections[detection_idx].class_idx;
104  detected_class_id[detection_idx] = class_idx;
105  if (class_idx < 0 || class_idx >= num_class_labels) {
106  detected_class_label.setNull(detection_idx);
107  } else {
108  detected_class_label[detection_idx] = class_idx_to_label_vec[class_idx];
109  }
110  detected_centroid_x[detection_idx] = processed_detections[detection_idx].centroid_x;
111  detected_centroid_y[detection_idx] = processed_detections[detection_idx].centroid_y;
112  detected_width[detection_idx] = processed_detections[detection_idx].width;
113  detected_height[detection_idx] = processed_detections[detection_idx].height;
114  detected_confidence[detection_idx] = processed_detections[detection_idx].confidence;
115  }
116  return num_detections;
117  } catch (const std::runtime_error& e) {
118  return mgr.ERROR_MESSAGE(e.what());
119  }
120  return 0;
121 }
void set_output_row_size(int64_t num_rows)
Definition: heavydbTypes.h:373
ModelInfo get_model_info_from_file(const std::string &filename)
std::string getString() const
Definition: heavydbTypes.h:641
std::vector< Detection > detect_objects_in_tiled_raster(const std::string &model_path, const ModelInfo &model_info, const bool use_gpu, const int64_t device_num, std::vector< float > &raster_data, const RasterFormat_Namespace::RasterInfo &raster_info, const float min_confidence_threshold, const float iou_threshold, std::shared_ptr< CpuTimer > timer)
DEVICE void setNull(int64_t index)
StringDictionaryProxy * string_dict_proxy_
std::vector< int32_t > getOrAddTransientBulk(const std::vector< std::string > &strings)

+ Here is the call graph for this function: