OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchWrapper.cpp File Reference
#include "TorchWrapper.h"
#include "Shared/funcannotations.h"
#include "TorchOps.hpp"
#include <torch/script.h>
#include <torch/torch.h>
#include <chrono>
#include <cmath>
#include <fstream>
#include <iostream>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "rapidjson/document.h"
+ Include dependency graph for TorchWrapper.cpp:

Go to the source code of this file.

Enumerations

enum  DetectionIdx {
  centroid_x = 0, centroid_y = 1, width = 2, height = 3,
  class_idx = 4, score = 5
}
 
enum  BoxDetectionIdx { tl_x = 0, tl_y = 1, br_x = 2, br_y = 3 }
 

Functions

std::string get_device_string (const bool use_gpu, const int64_t device_num)
 
bool should_use_half (const bool use_gpu, const std::string &model_path)
 
std::shared_ptr
< torch::jit::script::Module > 
get_model_from_cache (const std::string &model_path)
 
void add_model_to_cache (const std::string &model_path, std::shared_ptr< torch::jit::script::Module > model_module)
 
std::shared_ptr
< torch::jit::script::Module > 
load_module (const std::string &model_path, const std::string compute_device, const at::ScalarType data_type, const bool use_cache)
 
std::string get_json_str_from_file_header (const std::string &filename, const size_t max_search_chars)
 
ModelInfo get_model_info_from_json (const std::string &json_str)
 
torch::Tensor xywh2xyxy (const torch::Tensor &x)
 
torch::Tensor world_scale_detections (const torch::Tensor &input, const int64_t batch_idx, const RasterFormat_Namespace::RasterInfo &raster_info)
 
std::vector< Detection > process_detections (const torch::Tensor &raw_detections, const float min_confidence_threshold, const float iou_threshold, const ModelInfo &model_info, const RasterFormat_Namespace::RasterInfo &raster_info, std::shared_ptr< CpuTimer > timer)
 
std::vector< Detection > detect_objects_in_tiled_raster_impl (const std::string &model_path, const ModelInfo &model_info, const bool use_gpu, const int64_t device_num, std::vector< float > &raster_data, const RasterFormat_Namespace::RasterInfo &raster_info, const float min_confidence_threshold, const float iou_threshold, std::shared_ptr< CpuTimer > timer)
 
void print_model_params (const std::string &model_path, const bool use_gpu, const int64_t device_num)
 
 __attribute__ ((__used__)) ModelInfo get_model_info_from_file (const std::string &filename)
 

Variables

static std::unordered_map
< std::string, std::shared_ptr
< torch::jit::script::Module > > 
model_cache
 
static std::shared_mutex model_mutex
 

Enumeration Type Documentation

Enumerator
tl_x 
tl_y 
br_x 
br_y 

Definition at line 188 of file TorchWrapper.cpp.

// Column indices of a corner-form bounding box: top-left (x, y) followed
// by bottom-right (x, y). Used by xywh2xyxy() and the NMS kernel input.
enum BoxDetectionIdx {
  tl_x = 0,
  tl_y = 1,
  br_x = 2,
  br_y = 3,
};
Enumerator
centroid_x 
centroid_y 
width 
height 
class_idx 
score 

Definition at line 179 of file TorchWrapper.cpp.

// Column indices of a single detection row as produced by
// world_scale_detections(): centroid-form bbox (center x/y, width, height),
// then the predicted class index and its confidence score.
enum DetectionIdx {
  centroid_x = 0,
  centroid_y = 1,
  width = 2,
  height = 3,
  class_idx = 4,
  score = 5
};

Function Documentation

__attribute__ ( (__used__)  ) const

Definition at line 469 of file TorchWrapper.cpp.

References get_json_str_from_file_header(), get_model_info_from_json(), ModelInfo::is_valid, and json_str().

470  {
471  const std::string json_str =
472  get_json_str_from_file_header(filename, 100 /* max_search_chars */);
473  if (json_str.size() > 0) {
474  const ModelInfo model_info = get_model_info_from_json(json_str);
475  if (model_info.is_valid) {
476  return model_info;
477  }
478  }
479  return {};
480 }
std::string get_json_str_from_file_header(const std::string &filename, const size_t max_search_chars)
ModelInfo get_model_info_from_json(const std::string &json_str)
bool is_valid
Definition: TorchWrapper.h:47
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46

+ Here is the call graph for this function:

void add_model_to_cache ( const std::string &  model_path,
std::shared_ptr< torch::jit::script::Module >  model_module 
)

Definition at line 76 of file TorchWrapper.cpp.

References model_cache, and model_mutex.

Referenced by load_module().

77  {
78  std::unique_lock<std::shared_mutex> model_cache_write_lock(model_mutex);
79  model_cache.emplace(model_path, model_module);
80 }
static std::unordered_map< std::string, std::shared_ptr< torch::jit::script::Module > > model_cache
static std::shared_mutex model_mutex

+ Here is the caller graph for this function:

std::vector<Detection> detect_objects_in_tiled_raster_impl ( const std::string &  model_path,
const ModelInfo model_info,
const bool  use_gpu,
const int64_t  device_num,
std::vector< float > &  raster_data,
const RasterFormat_Namespace::RasterInfo raster_info,
const float  min_confidence_threshold,
const float  iou_threshold,
std::shared_ptr< CpuTimer timer 
)

Definition at line 354 of file TorchWrapper.cpp.

References RasterFormat_Namespace::RasterInfo::batch_tiles, get_device_string(), load_module(), process_detections(), RasterFormat_Namespace::RasterInfo::raster_channels, should_use_half(), RasterFormat_Namespace::RasterInfo::x_pixels_per_tile, RasterFormat_Namespace::RasterInfo::x_tiles, RasterFormat_Namespace::RasterInfo::y_pixels_per_tile, and RasterFormat_Namespace::RasterInfo::y_tiles.

363  {
364  const std::string compute_device = get_device_string(use_gpu, device_num);
365  const bool use_half = should_use_half(use_gpu, model_path);
366  const auto input_data_type = use_half ? torch::kHalf : torch::kFloat32;
367  const bool use_model_cache = use_gpu;
368 
369  try {
370  // Moved try block up to here as InferenceMode call below can throw if GPU is not
371  // specified correctly
372 #ifdef HAVE_CUDA_TORCH
373  c10::cuda::OptionalCUDAGuard cuda_guard;
374  if (use_gpu) {
375  cuda_guard.set_index(static_cast<int8_t>(device_num));
376  }
377 #endif
378 
379  c10::InferenceMode guard;
380  torch::NoGradGuard no_grad;
381 
382  timer->start_event_timer("Model load");
383 
384  auto module =
385  load_module(model_path, compute_device, input_data_type, use_model_cache);
386  timer->start_event_timer("Input prep");
387  std::cout << "Device: " << compute_device << " Use half: " << use_half << std::endl;
388 
389  std::cout << "X tiles: " << raster_info.x_tiles << " Y tiles: " << raster_info.y_tiles
390  << " Batch size: " << raster_info.batch_tiles << std::endl;
391 
392  auto input_tensor =
393  torch::from_blob(raster_data.data(),
394  {raster_info.batch_tiles,
395  raster_info.raster_channels,
396  raster_info.y_pixels_per_tile,
397  raster_info.x_pixels_per_tile} /*, tensor_options */)
398  .to(compute_device, input_data_type);
399 
400  std::vector<torch::jit::IValue> module_input;
401  module_input.emplace_back(input_tensor);
402 
403  timer->start_event_timer("Inference");
404  torch::jit::IValue output = module->forward(module_input);
405 
406  auto raw_detections = output.toTuple()->elements()[0].toTensor();
407 
408 #ifdef HAVE_CUDA_TORCH
409  constexpr bool enable_debug_timing{true};
410  if (enable_debug_timing && use_gpu) {
411  std::cout << "Synchronizing timing" << std::endl;
412  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();
413  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
414  }
415 #endif
416 
417  const auto processed_detections =
418  process_detections(raw_detections,
419  min_confidence_threshold,
420  iou_threshold,
421  model_info,
422  raster_info,
423  timer->start_nested_event_timer("process_detections"));
424 
425  return processed_detections;
426 
427  } catch (std::exception& e) {
428  std::string error_msg{"Error during model inference: "};
429  error_msg += e.what();
430  throw std::runtime_error(error_msg);
431  }
432 }
std::vector< Detection > process_detections(const torch::Tensor &raw_detections, const float min_confidence_threshold, const float iou_threshold, const ModelInfo &model_info, const RasterFormat_Namespace::RasterInfo &raster_info, std::shared_ptr< CpuTimer > timer)
std::string get_device_string(const bool use_gpu, const int64_t device_num)
bool should_use_half(const bool use_gpu, const std::string &model_path)
std::shared_ptr< torch::jit::script::Module > load_module(const std::string &model_path, const std::string compute_device, const at::ScalarType data_type, const bool use_cache)

+ Here is the call graph for this function:

/// @brief Maps (use_gpu, device_num) to a Torch device string.
///
/// @param use_gpu    Whether a CUDA device was requested.
/// @param device_num CUDA device ordinal to target.
/// @return "cuda:<device_num>" when built with CUDA support, a GPU was
///         requested, and Torch reports CUDA availability; "cpu" otherwise.
std::string get_device_string(const bool use_gpu, const int64_t device_num) {
#ifdef HAVE_CUDA_TORCH
  if (torch::cuda::is_available() && use_gpu) {
    return "cuda:" + std::to_string(device_num);
  }
#endif
  return "cpu";
}
std::string to_string(char const *&&v)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

/// @brief Extracts a brace-balanced JSON object embedded at the start of a file.
///
/// Reads the file character by character. The opening '{' must appear within
/// the first max_search_chars characters; once inside a JSON object the scan
/// continues past that limit until the braces balance out.
///
/// @param filename         File to scan.
/// @param max_search_chars How far into the file to look for the opening '{'.
/// @return The balanced "{...}" substring, or "" if the file cannot be opened,
///         no '{' is found in range, or the braces never balance (EOF first).
std::string get_json_str_from_file_header(const std::string& filename,
                                          const size_t max_search_chars) {
  std::ifstream model_file(filename);
  bool saw_open_brace = false;
  size_t nest_depth = 0;
  size_t chars_read{0};
  std::string header_json;
  if (model_file.is_open()) {
    char c;
    while (model_file.get(c)) {
      // Stop searching once past the limit, unless we are already inside
      // an object (then read on until it closes).
      if (nest_depth == 0 && chars_read >= max_search_chars) {
        break;
      }
      ++chars_read;
      if (c == '{') {
        saw_open_brace = true;
        ++nest_depth;
      }
      if (saw_open_brace) {
        header_json += c;
      }
      if (c == '}' && nest_depth > 0) {
        --nest_depth;
        if (saw_open_brace && nest_depth == 0) {
          break;  // outermost object closed
        }
      }
    }
  }
  if (saw_open_brace && nest_depth == 0) {
    return header_json;
  }
  return "";
}
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

std::shared_ptr<torch::jit::script::Module> get_model_from_cache ( const std::string &  model_path)

Definition at line 66 of file TorchWrapper.cpp.

References model_cache, and model_mutex.

Referenced by load_module().

67  {
68  std::shared_lock<std::shared_mutex> model_cache_read_lock(model_mutex);
69  auto model_itr = model_cache.find(model_path);
70  if (model_itr == model_cache.end()) {
71  return nullptr;
72  }
73  return model_itr->second;
74 }
static std::unordered_map< std::string, std::shared_ptr< torch::jit::script::Module > > model_cache
static std::shared_mutex model_mutex

+ Here is the caller graph for this function:

/// @brief Parses model metadata from a JSON string (model file header).
///
/// Recognized members: "shape" (array of 4 ints: batch, channels, tile height,
/// tile width), "stride" (int), and "names" (array of class-label strings).
/// Each member is optional; unrecognized or malformed members are skipped.
///
/// @param json_str JSON text to parse.
/// @return ModelInfo with is_valid == true when the JSON parsed successfully
///         (even if some members were absent); is_valid == false on a parse
///         error.
ModelInfo get_model_info_from_json(const std::string& json_str) {
  ModelInfo model_info;
  rapidjson::Document doc;
  if (doc.Parse<0>(json_str.c_str()).HasParseError()) {
    return model_info;  // will have is_valid set to false
  }
  const auto shape_array_itr = doc.FindMember("shape");
  if (shape_array_itr != doc.MemberEnd() && shape_array_itr->value.IsArray()) {
    const auto& shape_array = shape_array_itr->value;
    if (shape_array.Size() == 4) {
      // Guard each element: GetInt() on a non-int value is undefined
      // behavior / asserts in rapidjson.
      bool all_ints = true;
      for (rapidjson::SizeType i = 0; i < 4; ++i) {
        all_ints = all_ints && shape_array[i].IsInt();
      }
      if (all_ints) {
        model_info.batch_size = shape_array[0].GetInt();
        model_info.raster_channels = shape_array[1].GetInt();
        model_info.raster_tile_height = shape_array[2].GetInt();
        model_info.raster_tile_width = shape_array[3].GetInt();
      }
    }
  }
  const auto stride_itr = doc.FindMember("stride");
  if (stride_itr != doc.MemberEnd() && stride_itr->value.IsInt()) {
    model_info.stride = stride_itr->value.GetInt();
  }
  const auto class_labels_itr = doc.FindMember("names");
  if (class_labels_itr != doc.MemberEnd() && class_labels_itr->value.IsArray()) {
    const rapidjson::SizeType num_class_labels = class_labels_itr->value.Size();
    model_info.class_labels.reserve(static_cast<size_t>(num_class_labels));
    for (auto& label : class_labels_itr->value.GetArray()) {
      // Skip non-string entries rather than crash on GetString().
      if (label.IsString()) {
        model_info.class_labels.emplace_back(label.GetString());
      }
    }
  }
  model_info.is_valid = true;
  return model_info;
}
int64_t stride
Definition: TorchWrapper.h:52
bool is_valid
Definition: TorchWrapper.h:47
const std::string json_str(const rapidjson::Value &obj) noexcept
Definition: JsonAccessors.h:46
std::vector< std::string > class_labels
Definition: TorchWrapper.h:53
int64_t raster_tile_width
Definition: TorchWrapper.h:50
int64_t raster_channels
Definition: TorchWrapper.h:49
int64_t batch_size
Definition: TorchWrapper.h:48
int64_t raster_tile_height
Definition: TorchWrapper.h:51

+ Here is the caller graph for this function:

std::shared_ptr<torch::jit::script::Module> load_module ( const std::string &  model_path,
const std::string  compute_device,
const at::ScalarType  data_type,
const bool  use_cache 
)

Definition at line 82 of file TorchWrapper.cpp.

References add_model_to_cache(), get_model_from_cache(), and boost::serialization::load().

Referenced by detect_objects_in_tiled_raster_impl(), and print_model_params().

86  {
87  std::shared_ptr<torch::jit::script::Module> module;
88  try {
89  // Deserialize the ScriptModule from a file using torch::jit::load().
90  if (use_cache) {
91  module = get_model_from_cache(model_path);
92  }
93  if (module == nullptr) { // module not found or not using cache
94  module = std::make_shared<torch::jit::script::Module>(torch::jit::load(model_path));
95  module->to(compute_device, data_type);
96  module->eval();
97 
98  if (use_cache) {
99  add_model_to_cache(model_path, module);
100  }
101  } else {
102  module->eval();
103  }
104  } catch (const c10::Error& e) {
105  std::string error_msg{"Error loading the provided model: "};
106  error_msg += e.what();
107  throw std::runtime_error(error_msg);
108  }
109  return module;
110 }
std::shared_ptr< torch::jit::script::Module > get_model_from_cache(const std::string &model_path)
void load(Archive &ar, ExplainedQueryHint &query_hint, const unsigned int version)
void add_model_to_cache(const std::string &model_path, std::shared_ptr< torch::jit::script::Module > model_module)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

void print_model_params ( const std::string &  model_path,
const bool  use_gpu,
const int64_t  device_num 
)

Definition at line 434 of file TorchWrapper.cpp.

References get_device_string(), load_module(), and should_use_half().

Referenced by TorchWarmer::warmup_torch().

436  {
437  const std::string compute_device = get_device_string(use_gpu, device_num);
438  const bool use_half = should_use_half(use_gpu, model_path);
439  const auto input_data_type = use_half ? torch::kHalf : torch::kFloat32;
440 
441  try {
442  auto module =
443  load_module(model_path, compute_device, input_data_type, use_gpu /* use_cache */);
444  const auto module_named_params = module->named_parameters(true);
445  const size_t num_named_params = module_named_params.size();
446  std::cout << "Module # params: " << num_named_params << std::endl;
447  const size_t max_params_to_print{1000};
448  size_t param_idx{0};
449  for (const auto& param : module_named_params) {
450  std::cout << param.name << std::endl;
451  if (param_idx++ == max_params_to_print) {
452  break;
453  }
454  }
455  const auto module_named_buffers = module->named_buffers(true);
456  const size_t num_named_buffers = module_named_buffers.size();
457  std::cout << "Module # named buffers: " << num_named_buffers << std::endl;
458  const auto module_named_children = module->named_children();
459  const size_t num_named_children = module_named_children.size();
460  std::cout << "Module # named children: " << num_named_children << std::endl;
461  std::cout << "Finishing torch warmup" << std::endl;
462  } catch (std::exception& e) {
463  std::string error_msg{"Error fetching Torch model params: "};
464  error_msg += e.what();
465  std::cout << error_msg << std::endl;
466  }
467 }
std::string get_device_string(const bool use_gpu, const int64_t device_num)
bool should_use_half(const bool use_gpu, const std::string &model_path)
std::shared_ptr< torch::jit::script::Module > load_module(const std::string &model_path, const std::string compute_device, const at::ScalarType data_type, const bool use_cache)

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

std::vector<Detection> process_detections ( const torch::Tensor &  raw_detections,
const float  min_confidence_threshold,
const float  iou_threshold,
const ModelInfo model_info,
const RasterFormat_Namespace::RasterInfo raster_info,
std::shared_ptr< CpuTimer timer 
)

Definition at line 237 of file TorchWrapper.cpp.

References cat(), centroid_x, centroid_y, class_idx, ModelInfo::class_labels, RasterFormat_Namespace::RasterInfo::halo_x_pixels_per_tile_boundary, RasterFormat_Namespace::RasterInfo::halo_y_pixels_per_tile_boundary, height, logical_and(), nms_kernel(), score, width, world_scale_detections(), RasterFormat_Namespace::RasterInfo::x_pixels_per_tile, RasterFormat_Namespace::RasterInfo::x_tiles, xywh2xyxy(), RasterFormat_Namespace::RasterInfo::y_pixels_per_tile, and RasterFormat_Namespace::RasterInfo::y_tiles.

Referenced by detect_objects_in_tiled_raster_impl().

243  {
244  // Most of logic in this function borrowed liberally from
245  // https://github.com/yasenh/libtorch-yolov5 (MIT Licensed)
246  timer->start_event_timer("Confidence mask");
247  constexpr int64_t item_attr_size = 5;
248  const auto& class_labels = model_info.class_labels;
249  const int32_t num_class_labels = static_cast<int32_t>(class_labels.size());
250  const auto batch_size = raster_info.x_tiles * raster_info.y_tiles;
251  const auto num_classes = raw_detections.size(2) - item_attr_size;
252  auto conf_mask = raw_detections.select(2, 4).ge(min_confidence_threshold).unsqueeze(2);
253  torch::Tensor all_world_scaled_detections;
254  for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
255  auto masked_detections =
256  torch::masked_select(raw_detections[batch_idx], conf_mask[batch_idx])
257  .view({-1, num_classes + item_attr_size});
258 
259  if (masked_detections.size(0) == 0) {
260  continue;
261  }
262  // compute overall score = obj_conf * cls_conf, similar to x[:, 5:] *= x[:, 4:5]
263  masked_detections.slice(1, item_attr_size, item_attr_size + num_classes) *=
264  masked_detections.select(1, 4).unsqueeze(1);
265 
266  // [best class only] get the max classes score at each result (e.g. elements 5-84)
267  std::tuple<torch::Tensor, torch::Tensor> max_classes = torch::max(
268  masked_detections.slice(1, item_attr_size, item_attr_size + num_classes), 1);
269 
270  // class score
271  auto max_conf_scores = std::get<0>(max_classes);
272  // index
273  auto max_conf_classes = std::get<1>(max_classes);
274 
275  max_conf_scores = max_conf_scores.to(torch::kFloat).unsqueeze(1);
276  max_conf_classes = max_conf_classes.to(torch::kFloat).unsqueeze(1);
277  masked_detections = torch::cat(
278  {masked_detections.slice(1, 0, 4), max_conf_classes, max_conf_scores}, 1);
279 
280  if (raster_info.halo_x_pixels_per_tile_boundary > 0 ||
281  raster_info.halo_y_pixels_per_tile_boundary > 0) {
282  const double min_x_pixel = raster_info.halo_x_pixels_per_tile_boundary;
283  const double max_x_pixel =
284  raster_info.x_pixels_per_tile - 1 - raster_info.halo_x_pixels_per_tile_boundary;
285  const double min_y_pixel = raster_info.halo_y_pixels_per_tile_boundary;
286  const double max_y_pixel =
287  raster_info.y_pixels_per_tile - 1 - raster_info.halo_y_pixels_per_tile_boundary;
288  auto x_halo_mask =
289  torch::logical_and(masked_detections.select(1, 0).ge(min_x_pixel),
290  masked_detections.select(1, 0).le(max_x_pixel));
291  auto y_halo_mask =
292  torch::logical_and(masked_detections.select(1, 1).ge(min_y_pixel),
293  masked_detections.select(1, 1).le(max_y_pixel));
294 
295  auto halo_mask = torch::logical_and(x_halo_mask, y_halo_mask).unsqueeze(1);
296  masked_detections =
297  torch::masked_select(masked_detections, halo_mask).view({-1, 6});
298  }
299 
300  auto world_scaled_detections =
301  world_scale_detections(masked_detections, batch_idx, raster_info);
302 
303  auto world_scaled_detections_cpu = world_scaled_detections.cpu();
304  if (batch_idx == 0) {
305  all_world_scaled_detections = world_scaled_detections_cpu.cpu();
306  } else {
307  all_world_scaled_detections =
308  torch::cat({all_world_scaled_detections, world_scaled_detections_cpu}, 0).cpu();
309  }
310  }
311  timer->start_event_timer("Per-batch processing");
312  std::vector<Detection> processed_detections;
313  if (all_world_scaled_detections.size(0) == 0) {
314  return processed_detections;
315  }
316 
317  torch::Tensor bboxes = xywh2xyxy(all_world_scaled_detections.slice(1, 0, 4));
318 
319  auto kept_bboxes_idxs =
320  nms_kernel(bboxes, all_world_scaled_detections.select(1, 5), iou_threshold);
321 
322  timer->start_event_timer("Nms processing");
323 
324  const int64_t num_kept_detections = kept_bboxes_idxs.size(0);
325  processed_detections.reserve(num_kept_detections);
326 
327  const auto& kept_bboxes_idxs_accessor = kept_bboxes_idxs.accessor<int64_t, 1>();
328  const auto& detections_array_accessor =
329  all_world_scaled_detections.accessor<double, 2>();
330 
331  for (int64_t detection_idx = 0; detection_idx < num_kept_detections; ++detection_idx) {
332  int64_t kept_detection_idx = kept_bboxes_idxs_accessor[detection_idx];
333  const auto& detection_array = detections_array_accessor[kept_detection_idx];
334  const int32_t class_idx =
335  static_cast<int32_t>(round(detection_array[DetectionIdx::class_idx]));
336  std::string class_label;
337  if (class_idx < num_class_labels) {
338  class_label = class_labels[class_idx];
339  }
340  Detection processed_detection{
341  class_idx,
342  class_label,
343  detection_array[DetectionIdx::centroid_x],
344  detection_array[DetectionIdx::centroid_y],
345  detection_array[DetectionIdx::width],
346  detection_array[DetectionIdx::height],
347  static_cast<float>(detection_array[DetectionIdx::score])};
348  processed_detections.emplace_back(processed_detection);
349  }
350  timer->start_event_timer("Output processing");
351  return processed_detections;
352 }
torch::Tensor xywh2xyxy(const torch::Tensor &x)
std::string cat(Ts &&...args)
RUNTIME_EXPORT ALWAYS_INLINE int8_t logical_and(const int8_t lhs, const int8_t rhs, const int8_t null_val)
torch::Tensor world_scale_detections(const torch::Tensor &input, const int64_t batch_idx, const RasterFormat_Namespace::RasterInfo &raster_info)
std::vector< std::string > class_labels
Definition: TorchWrapper.h:53
const int64_t halo_y_pixels_per_tile_boundary
Definition: RasterInfo.h:30
at::Tensor nms_kernel(const at::Tensor &dets, const at::Tensor &scores, double iou_threshold)
Definition: TorchOps.hpp:81
const int64_t halo_x_pixels_per_tile_boundary
Definition: RasterInfo.h:29

+ Here is the call graph for this function:

+ Here is the caller graph for this function:

/// @brief Decides whether to run the model at fp16 precision.
///
/// Half precision is only enabled in CUDA builds, when a GPU is requested,
/// and when the model filename contains "half" (the convention used for
/// fp16-exported model files).
///
/// @param use_gpu    Whether inference will run on a GPU.
/// @param model_path Model filename/path inspected for the "half" marker.
/// @return true only when all of the above hold; always false on CPU builds.
bool should_use_half(const bool use_gpu, const std::string& model_path) {
#ifdef HAVE_CUDA_TORCH
  return use_gpu && model_path.find("half") != std::string::npos;
#else
  return false;
#endif
}

+ Here is the caller graph for this function:

/// @brief Converts tile-local pixel detections to world coordinates.
///
/// Given detection rows (see DetectionIdx layout) for one tile in the batch,
/// offsets the centroid by the tile's origin in global pixel space and scales
/// pixels to world input units; width/height are scaled only. Class index and
/// score pass through unchanged. Output dtype is float64.
///
/// @param input       Tile-local detections, shape (rows, 6).
/// @param batch_idx   Flat tile index; decomposed into (tile_x, tile_y) via
///                    raster_info.x_tiles.
/// @param raster_info Tile geometry and world-scaling factors.
/// @return World-space detections with the same (rows, 6) layout.
torch::Tensor world_scale_detections(
    const torch::Tensor& input,
    const int64_t batch_idx,
    const RasterFormat_Namespace::RasterInfo& raster_info) {
  const int64_t tile_y_idx = batch_idx / raster_info.x_tiles;
  const int64_t tile_x_idx = batch_idx % raster_info.x_tiles;
  // NOTE(review): the two continuation lines below were truncated in the
  // source dump; reconstructed on the basis that tile-local coords include
  // the halo padding, so the tile origin shifts back by the halo width
  // (the halo members are listed in this function's references) — verify
  // against the original TorchWrapper.cpp lines 212-215.
  const double tile_x0_pixel = tile_x_idx * raster_info.logical_x_pixels_per_tile -
                               raster_info.halo_x_pixels_per_tile_boundary;
  const double tile_y0_pixel = tile_y_idx * raster_info.logical_y_pixels_per_tile -
                               raster_info.halo_y_pixels_per_tile_boundary;
  auto options = torch::TensorOptions().dtype(torch::kFloat64);

  auto output = torch::zeros_like(input, options);
  output.select(1, DetectionIdx::centroid_x) =
      (input.select(1, DetectionIdx::centroid_x) + tile_x0_pixel) *
          raster_info.x_input_units_per_pixel +
      raster_info.min_x_input;
  output.select(1, DetectionIdx::centroid_y) =
      (input.select(1, DetectionIdx::centroid_y) + tile_y0_pixel) *
          raster_info.y_input_units_per_pixel +
      raster_info.min_y_input;
  output.select(1, DetectionIdx::width) =
      input.select(1, DetectionIdx::width) * raster_info.x_input_units_per_pixel;
  output.select(1, DetectionIdx::height) =
      input.select(1, DetectionIdx::height) * raster_info.y_input_units_per_pixel;
  output.select(1, DetectionIdx::class_idx) = input.select(1, DetectionIdx::class_idx);
  output.select(1, DetectionIdx::score) = input.select(1, DetectionIdx::score);
  return output;
}
const int64_t logical_y_pixels_per_tile
Definition: RasterInfo.h:32
const int64_t logical_x_pixels_per_tile
Definition: RasterInfo.h:31
const int64_t halo_y_pixels_per_tile_boundary
Definition: RasterInfo.h:30
const int64_t halo_x_pixels_per_tile_boundary
Definition: RasterInfo.h:29

+ Here is the caller graph for this function:

/// @brief Converts bounding boxes from centroid form to corner form.
///
/// @param x Boxes as columns (center x, center y, width, height).
/// @return Boxes as columns (x1, y1, x2, y2) — top-left and bottom-right
///         corners, per BoxDetectionIdx.
torch::Tensor xywh2xyxy(const torch::Tensor& x) {
  auto y = torch::zeros_like(x);
  const auto center_x = x.select(1, 0);
  const auto center_y = x.select(1, 1);
  const auto half_width = x.select(1, 2).div(2);
  const auto half_height = x.select(1, 3).div(2);
  y.select(1, BoxDetectionIdx::tl_x) = center_x - half_width;
  y.select(1, BoxDetectionIdx::tl_y) = center_y - half_height;
  y.select(1, BoxDetectionIdx::br_x) = center_x + half_width;
  y.select(1, BoxDetectionIdx::br_y) = center_y + half_height;
  return y;
}

+ Here is the caller graph for this function:

Variable Documentation

std::unordered_map<std::string, std::shared_ptr<torch::jit::script::Module> > model_cache
static

Definition at line 63 of file TorchWrapper.cpp.

Referenced by add_model_to_cache(), and get_model_from_cache().

std::shared_mutex model_mutex
static

Definition at line 64 of file TorchWrapper.cpp.

Referenced by add_model_to_cache(), and get_model_from_cache().