SourceXtractorPlusPlus  0.15
Please provide a description of the project.
KdTree.icpp
Go to the documentation of this file.
1 /** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
2  *
3  * This library is free software; you can redistribute it and/or modify it under
4  * the terms of the GNU Lesser General Public License as published by the Free
5  * Software Foundation; either version 3.0 of the License, or (at your option)
6  * any later version.
7  *
8  * This library is distributed in the hope that it will be useful, but WITHOUT
9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
10  * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
11  * details.
12  *
13  * You should have received a copy of the GNU Lesser General Public License
14  * along with this library; if not, write to the Free Software Foundation, Inc.,
15  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16  */
17 
18 namespace SourceXtractor {
19 
20 template<typename T, size_t N, size_t S>
21 class KdTree<T, N, S>::Node {
22 public:
23  virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const = 0;
24  virtual ~Node() = default;
25 };
26 
27 template<typename T, size_t N, size_t S>
28 class KdTree<T, N, S>::Leaf : public KdTree::Node {
29 public:
30  explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
31  virtual ~Leaf() = default;
32 
33  virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
34  std::vector<T> selection;
35  for (auto& entry : m_data) {
36  double square_dist = 0.0;
37  for (size_t i =0; i < N; i++) {
38  double delta = Traits::getCoord(entry, i) - coord.coord[i];
39  square_dist += delta * delta;
40  }
41  if (square_dist < radius*radius) {
42  selection.push_back(entry);
43  }
44  }
45  return selection;
46  }
47 
48 private:
49  const std::vector<T> m_data;
50 };
51 
52 template<typename T, size_t N, size_t S>
53 class KdTree<T, N, S>::Split : public KdTree::Node {
54 public:
55  virtual ~Split() = default;
56  explicit Split(std::vector<T> data, size_t axis) : m_axis(axis) {
57  std::sort(data.begin(), data.end(), [axis](const T& a, const T& b) -> bool {
58  return Traits::getCoord(a, axis) < Traits::getCoord(b, axis);
59  });
60 
61  double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
62  double b = Traits::getCoord(data.at(data.size() / 2), axis);
63 
64  if (a == b) {
65  // avoid a possible rounding issue
66  m_split_value = a;
67  } else {
68  m_split_value = (a + b) / 2.0;
69  }
70 
71  std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
72  std::vector<T> right(data.begin() + data.size() / 2, data.end());
73 
74  if (left.size() > S) {
75  m_left_child = std::make_shared<Split>(std::move(left), (axis+1) % N);
76  } else {
77  m_left_child = std::make_shared<Leaf>(std::move(left));
78  }
79  if (right.size() > S) {
80  m_right_child = std::make_shared<Split>(std::move(right), (axis+1) % N);
81  } else {
82  m_right_child = std::make_shared<Leaf>(std::move(right));
83  }
84  }
85 
86  virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
87  if (coord.coord[m_axis] + radius < m_split_value) {
88  return m_left_child->findPointsWithinRadius(coord, radius);
89  } else if (coord.coord[m_axis] - radius > m_split_value) {
90  return m_right_child->findPointsWithinRadius(coord, radius);
91  } else {
92  auto left = m_left_child->findPointsWithinRadius(coord, radius);
93  auto right = m_right_child->findPointsWithinRadius(coord, radius);
94 
95  std::vector<T> merge;
96  merge.reserve(left.size() + right.size());
97  merge.insert(merge.end(), left.begin(), left.end());
98  merge.insert(merge.end(), right.begin(), right.end());
99 
100  return merge;
101  }
102  }
103 
104 private:
105  size_t m_axis;
106  double m_split_value;
107 
108  std::shared_ptr<Node> m_left_child;
109  std::shared_ptr<Node> m_right_child;
110 };
111 
112 template<typename T, size_t N, size_t S>
113 KdTree<T, N, S>::KdTree(const std::vector<T>& data) {
114  if (data.size() > S) {
115  m_root = std::make_shared<Split>(data, 0);
116  } else {
117  std::vector<T> data_copy(data);
118  m_root = std::make_shared<Leaf>(std::move(data_copy));
119  }
120 }
121 
122 template<typename T, size_t N, size_t S>
123 std::vector<T> KdTree<T, N, S>::findPointsWithinRadius(Coord coord, double radius) const {
124  return m_root->findPointsWithinRadius(coord, radius);
125 }
126 
127 }