File indexing completed on 2025-01-26 04:25:02
0001 /* 0002 * file: KDTree.hpp 0003 * author: J. Frederico Carvalho 0004 * 0005 * This is an adaptation of the KD-tree implementation in rosetta code 0006 * https://rosettacode.org/wiki/K-d_tree 0007 * 0008 * It is a reimplementation of the C code using C++. It also includes a few 0009 * more queries than the original, namely finding all points at a distance 0010 * smaller than some given distance to a point. 0011 * 0012 */ 0013 0014 #include <algorithm> 0015 #include <cmath> 0016 #include <functional> 0017 #include <iterator> 0018 #include <limits> 0019 #include <memory> 0020 #include <vector> 0021 0022 #include "kdtree.hpp" 0023 0024 KDNode::KDNode() = default; 0025 0026 KDNode::KDNode(const point_t &pt, const size_t &idx_, const KDNodePtr &left_, 0027 const KDNodePtr &right_) { 0028 x = pt; 0029 index = idx_; 0030 left = left_; 0031 right = right_; 0032 } 0033 0034 KDNode::KDNode(const pointIndex &pi, const KDNodePtr &left_, 0035 const KDNodePtr &right_) { 0036 x = pi.first; 0037 index = pi.second; 0038 left = left_; 0039 right = right_; 0040 } 0041 0042 KDNode::~KDNode() = default; 0043 0044 double KDNode::coord(const size_t &idx) { return x.at(idx); } 0045 KDNode::operator bool() { return (!x.empty()); } 0046 KDNode::operator point_t() { return x; } 0047 KDNode::operator size_t() { return index; } 0048 KDNode::operator pointIndex() { return pointIndex(x, index); } 0049 0050 KDNodePtr NewKDNodePtr() { 0051 KDNodePtr mynode = std::make_shared< KDNode >(); 0052 return mynode; 0053 } 0054 0055 inline double dist2(const point_t &a, const point_t &b) { 0056 double distc = 0; 0057 for (size_t i = 0; i < a.size(); i++) { 0058 double di = a.at(i) - b.at(i); 0059 distc += di * di; 0060 } 0061 return distc; 0062 } 0063 0064 inline double dist2(const KDNodePtr &a, const KDNodePtr &b) { 0065 return dist2(a->x, b->x); 0066 } 0067 0068 inline double dist(const point_t &a, const point_t &b) { 0069 return std::sqrt(dist2(a, b)); 0070 } 0071 0072 inline double dist(const KDNodePtr &a, const KDNodePtr &b) { 0073 return std::sqrt(dist2(a, b)); 0074 } 0075 0076 comparer::comparer(size_t idx_) : idx{idx_} {}; 0077 0078 inline bool comparer::compare_idx(const pointIndex &a, // 0079 const pointIndex &b // 0080 ) { 0081 return (a.first.at(idx) < b.first.at(idx)); // 0082 } 0083 0084 inline void sort_on_idx(const pointIndexArr::iterator &begin, // 0085 const pointIndexArr::iterator &end, // 0086 size_t idx) { 0087 comparer comp(idx); 0088 comp.idx = idx; 0089 0090 using std::placeholders::_1; 0091 using std::placeholders::_2; 0092 0093 std::nth_element(begin, begin + std::distance(begin, end) / 2, 0094 end, std::bind(&comparer::compare_idx, comp, _1, _2)); 0095 } 0096 0097 using pointVec = std::vector< point_t >; 0098 0099 KDNodePtr KDTree::make_tree(const pointIndexArr::iterator &begin, // 0100 const pointIndexArr::iterator &end, // 0101 const size_t &length, // 0102 const size_t &level // 0103 ) { 0104 if (begin == end) { 0105 return NewKDNodePtr(); // empty tree 0106 } 0107 0108 size_t dim = begin->first.size(); 0109 0110 if (length > 1) { 0111 sort_on_idx(begin, end, level); 0112 } 0113 0114 auto middle = begin + (length / 2); 0115 0116 auto l_begin = begin; 0117 auto l_end = middle; 0118 auto r_begin = middle + 1; 0119 auto r_end = end; 0120 0121 size_t l_len = length / 2; 0122 size_t r_len = length - l_len - 1; 0123 0124 KDNodePtr left; 0125 if (l_len > 0 && dim > 0) { 0126 left = make_tree(l_begin, l_end, l_len, (level + 1) % dim); 0127 } else { 0128 left = leaf; 0129 } 0130 KDNodePtr right; 0131 if (r_len > 0 && dim > 0) { 0132 right = make_tree(r_begin, r_end, r_len, (level + 1) % dim); 0133 } else { 0134 right = leaf; 0135 } 0136 0137 // KDNode result = KDNode(); 0138 return std::make_shared< KDNode >(*middle, left, right); 0139 } 0140 0141 KDTree::KDTree(pointVec point_array) { 0142 leaf = std::make_shared< KDNode >(); 0143 // iterators 0144 pointIndexArr arr; 0145 for (size_t i = 0; i < point_array.size(); i++) { 0146 arr.push_back(pointIndex(point_array.at(i), i)); 0147 } 0148 0149 auto begin = arr.begin(); 0150 auto end = arr.end(); 0151 0152 size_t length = arr.size(); 0153 size_t level = 0; // starting 0154 0155 root = KDTree::make_tree(begin, end, length, level); 0156 0157 m_empty = point_array.size(); 0158 } 0159 0160 KDNodePtr KDTree::nearest_( // 0161 const KDNodePtr &branch, // 0162 const point_t &pt, // 0163 const size_t &level, // 0164 const KDNodePtr &best, // 0165 const double &best_dist // 0166 ) { 0167 double d, dx, dx2; 0168 0169 if (!bool(*branch)) { 0170 return NewKDNodePtr(); // basically, null 0171 } 0172 0173 point_t branch_pt(*branch); 0174 size_t dim = branch_pt.size(); 0175 0176 d = dist2(branch_pt, pt); 0177 dx = branch_pt.at(level) - pt.at(level); 0178 dx2 = dx * dx; 0179 0180 KDNodePtr best_l = best; 0181 double best_dist_l = best_dist; 0182 0183 if (d < best_dist) { 0184 best_dist_l = d; 0185 best_l = branch; 0186 } 0187 0188 size_t next_lv = (level + 1) % dim; 0189 KDNodePtr section; 0190 KDNodePtr other; 0191 0192 // select which branch makes sense to check 0193 if (dx > 0) { 0194 section = branch->left; 0195 other = branch->right; 0196 } else { 0197 section = branch->right; 0198 other = branch->left; 0199 } 0200 0201 // keep nearest neighbor from further down the tree 0202 KDNodePtr further = nearest_(section, pt, next_lv, best_l, best_dist_l); 0203 if (!further->x.empty()) { 0204 double dl = dist2(further->x, pt); 0205 if (dl < best_dist_l) { 0206 best_dist_l = dl; 0207 best_l = further; 0208 } 0209 } 0210 // only check the other branch if it makes sense to do so 0211 if (dx2 < best_dist_l) { 0212 further = nearest_(other, pt, next_lv, best_l, best_dist_l); 0213 if (!further->x.empty()) { 0214 double dl = dist2(further->x, pt); 0215 if (dl < best_dist_l) { 0216 best_dist_l = dl; 0217 best_l = further; 0218 } 0219 } 0220 } 0221 0222 return best_l; 0223 }; 0224 0225 // default caller 0226 KDNodePtr KDTree::nearest_(const point_t &pt) { 0227 size_t level = 0; 0228 // KDNodePtr best = branch; 0229 double branch_dist = dist2(point_t(*root), pt); 0230 return nearest_(root, // beginning of tree 0231 pt, // point we are querying 0232 level, // start from level 0 0233 root, // best is the root 0234 branch_dist); // best_dist = branch_dist 0235 }; 0236 0237 point_t KDTree::nearest_point(const point_t &pt) { 0238 return point_t(*nearest_(pt)); 0239 }; 0240 size_t KDTree::nearest_index(const point_t &pt) { 0241 return size_t(*nearest_(pt)); 0242 }; 0243 0244 pointIndex KDTree::nearest_pointIndex(const point_t &pt) { 0245 KDNodePtr Nearest = nearest_(pt); 0246 return pointIndex(point_t(*Nearest), size_t(*Nearest)); 0247 } 0248 0249 bool KDTree::empty() 0250 { 0251 return m_empty; 0252 } 0253 0254 pointIndexArr KDTree::neighborhood_( // 0255 const KDNodePtr &branch, // 0256 const point_t &pt, // 0257 const double &rad, // 0258 const size_t &level // 0259 ) { 0260 double d, dx, dx2; 0261 0262 if (!bool(*branch)) { 0263 // branch has no point, means it is a leaf, 0264 // no points to add 0265 return pointIndexArr(); 0266 } 0267 0268 size_t dim = pt.size(); 0269 0270 double r2 = rad * rad; 0271 0272 d = dist2(point_t(*branch), pt); 0273 dx = point_t(*branch).at(level) - pt.at(level); 0274 dx2 = dx * dx; 0275 0276 pointIndexArr nbh, nbh_s, nbh_o; 0277 if (d <= r2) { 0278 nbh.push_back(pointIndex(*branch)); 0279 } 0280 0281 // 0282 KDNodePtr section; 0283 KDNodePtr other; 0284 if (dx > 0) { 0285 section = branch->left; 0286 other = branch->right; 0287 } else { 0288 section = branch->right; 0289 other = branch->left; 0290 } 0291 0292 nbh_s = neighborhood_(section, pt, rad, (level + 1) % dim); 0293 nbh.insert(nbh.end(), nbh_s.begin(), nbh_s.end()); 0294 if (dx2 < r2) { 0295 nbh_o = neighborhood_(other, pt, rad, (level + 1) % dim); 0296 nbh.insert(nbh.end(), nbh_o.begin(), nbh_o.end()); 0297 } 0298 0299 return nbh; 0300 }; 0301 0302 pointIndexArr KDTree::neighborhood( // 0303 const point_t &pt, // 0304 const double &rad) { 0305 size_t level = 0; 0306 return neighborhood_(root, pt, rad, level); 0307 } 0308 0309 pointVec KDTree::neighborhood_points( // 0310 const point_t &pt, // 0311 const double &rad) { 0312 size_t level = 0; 0313 pointIndexArr nbh = neighborhood_(root, pt, rad, level); 0314 pointVec nbhp; 0315 nbhp.resize(nbh.size()); 0316 std::transform(nbh.begin(), nbh.end(), nbhp.begin(), 0317 [](pointIndex x) { return x.first; }); 0318 return nbhp; 0319 } 0320 0321 indexArr KDTree::neighborhood_indices( // 0322 const point_t &pt, // 0323 const double &rad) { 0324 size_t level = 0; 0325 pointIndexArr nbh = neighborhood_(root, pt, rad, level); 0326 indexArr nbhi; 0327 nbhi.resize(nbh.size()); 0328 std::transform(nbh.begin(), nbh.end(), nbhi.begin(), 0329 [](pointIndex x) { return x.second; }); 0330 return nbhi; 0331 }