OmniSciDB  72c90bc290
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
TorchOps.hpp File Reference
#include <ATen/ATen.h>
#include <torch/library.h>
+ Include dependency graph for TorchOps.hpp:
+ This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

template<typename scalar_t >
at::Tensor nms_kernel_impl (const at::Tensor &dets, const at::Tensor &scores, double iou_threshold)
 
at::Tensor nms_kernel (const at::Tensor &dets, const at::Tensor &scores, double iou_threshold)
 

Function Documentation

at::Tensor nms_kernel ( const at::Tensor &  dets,
const at::Tensor &  scores,
double  iou_threshold 
)

Definition at line 81 of file TorchOps.hpp.

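References run_benchmark_import::result. (Note: the body listed below only uses a local variable named `result`, declared at line 96 of TorchOps.hpp. This cross-reference appears to be Doxygen matching that name to an unrelated symbol.)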

Referenced by process_detections().

// nms_kernel: input validation + dtype dispatch for CPU non-maximum suppression.
//   dets          - box tensor; must be 2-D with exactly 4 columns (checked below).
//                   nms_kernel_impl reads the columns as x1, y1, x2, y2.
//   scores        - 1-D score tensor with one entry per row of `dets` (checked below).
//   iou_threshold - a box is dropped when its IoU with an already-kept,
//                   higher-scoring box is strictly greater than this value.
// Returns the tensor produced by nms_kernel_impl<scalar_t>: int64 indices of the
// kept rows of `dets`, in descending score order.
// The CPU-device and dets/scores dtype-agreement checks live in nms_kernel_impl,
// not here.
83  {
84  TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
85  TORCH_CHECK(dets.size(1) == 4,
86  "boxes should have 4 elements in dimension 1, got ",
87  dets.size(1));
88  TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
89  TORCH_CHECK(dets.size(0) == scores.size(0),
90  "boxes and scores should have same number of elements in ",
91  "dimension 0, got ",
92  dets.size(0),
93  " and ",
94  scores.size(0));
95 
96  auto result = at::empty({0}, dets.options());  // placeholder; overwritten inside the dispatch lambda below
97 
98  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {  // dispatches on the dtype of `dets` (floating types only); scalar_t is supplied by the macro
99  result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
100  });
101  return result;
102 }

+ Here is the caller graph for this function:

template<typename scalar_t >
at::Tensor nms_kernel_impl ( const at::Tensor &  dets,
const at::Tensor &  scores,
double  iou_threshold 
)

Definition at line 14 of file TorchOps.hpp.

// nms_kernel_impl: greedy non-maximum suppression on CPU for one floating-point
// scalar type `scalar_t`.
//   dets          - expected to be (N, 4); columns 0..3 are read as x1, y1, x2, y2.
//                   The shape checks live in nms_kernel, not here.
//   scores        - N scores; boxes are visited in descending score order.
//   iou_threshold - a lower-ranked box is suppressed when its IoU with a kept
//                   box is strictly greater than this value.
// Returns a 1-D int64 tensor of kept row indices into `dets`, in descending
// score order. Equal scores keep their original relative order because the
// sort is stable. The result is a narrowed view of an N-element buffer.
// In the worst case the pairwise loop below is quadratic in N.
16  {
17  TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
18  TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
19  TORCH_CHECK(dets.scalar_type() == scores.scalar_type(),
20  "dets should have the same type as scores");
21 
22  if (dets.numel() == 0)  // no boxes: return an empty int64 index tensor
23  return at::empty({0}, dets.options().dtype(at::kLong));
24 
25  auto x1_t = dets.select(1, 0).contiguous();  // column views are strided; contiguous() lets the raw pointers below be indexed by row
26  auto y1_t = dets.select(1, 1).contiguous();
27  auto x2_t = dets.select(1, 2).contiguous();
28  auto y2_t = dets.select(1, 3).contiguous();
29 
30  at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);  // per-box area, computed once up front (not clamped at zero)
31 
32  auto order_t =
33  std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));  // row indices sorted by descending score
34 
35  auto ndets = dets.size(0);
36  at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));  // per-box flag: 1 = suppressed
37  at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));  // sized for the worst case (all kept); trimmed by narrow() at the end
38 
39  auto suppressed = suppressed_t.data_ptr<uint8_t>();  // raw pointers for the pairwise loop
40  auto keep = keep_t.data_ptr<int64_t>();
41  auto order = order_t.data_ptr<int64_t>();
42  auto x1 = x1_t.data_ptr<scalar_t>();
43  auto y1 = y1_t.data_ptr<scalar_t>();
44  auto x2 = x2_t.data_ptr<scalar_t>();
45  auto y2 = y2_t.data_ptr<scalar_t>();
46  auto areas = areas_t.data_ptr<scalar_t>();
47 
48  int64_t num_to_keep = 0;
49 
50  for (int64_t _i = 0; _i < ndets; _i++) {  // visit boxes from highest to lowest score
51  auto i = order[_i];
52  if (suppressed[i] == 1)  // already removed by a higher-scoring kept box
53  continue;
54  keep[num_to_keep++] = i;
55  auto ix1 = x1[i];
56  auto iy1 = y1[i];
57  auto ix2 = x2[i];
58  auto iy2 = y2[i];
59  auto iarea = areas[i];
60 
61  for (int64_t _j = _i + 1; _j < ndets; _j++) {  // compare the kept box against every lower-ranked box
62  auto j = order[_j];
63  if (suppressed[j] == 1)
64  continue;
65  auto xx1 = std::max(ix1, x1[j]);
66  auto yy1 = std::max(iy1, y1[j]);
67  auto xx2 = std::min(ix2, x2[j]);
68  auto yy2 = std::min(iy2, y2[j]);
69 
70  auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);  // clamp so non-overlapping boxes get zero intersection
71  auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
72  auto inter = w * h;
73  auto ovr = inter / (iarea + areas[j] - inter);  // IoU; two zero-area boxes give 0/0 = NaN, and NaN > iou_threshold is false
74  if (ovr > iou_threshold)  // strictly greater: IoU exactly equal to the threshold is not suppressed
75  suppressed[j] = 1;
76  }
77  }
78  return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);  // view of the first num_to_keep entries
79 }