File indexing completed on 2025-01-26 04:25:02

0001 /*
0002  * file: KDTree.hpp
0003  * author: J. Frederico Carvalho
0004  *
0005  * This is an adaptation of the KD-tree implementation in rosetta code
0006  * https://rosettacode.org/wiki/K-d_tree
0007  *
0008  * It is a reimplementation of the C code using C++.  It also includes a few
0009  * more queries than the original, namely finding all points at a distance
0010  * smaller than some given distance to a point.
0011  *
0012  */
0013 
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "kdtree.hpp"
0023 
0024 KDNode::KDNode() = default;
0025 
0026 KDNode::KDNode(const point_t &pt, const size_t &idx_, const KDNodePtr &left_,
0027                const KDNodePtr &right_) {
0028     x = pt;
0029     index = idx_;
0030     left = left_;
0031     right = right_;
0032 }
0033 
0034 KDNode::KDNode(const pointIndex &pi, const KDNodePtr &left_,
0035                const KDNodePtr &right_) {
0036     x = pi.first;
0037     index = pi.second;
0038     left = left_;
0039     right = right_;
0040 }
0041 
0042 KDNode::~KDNode() = default;
0043 
0044 double KDNode::coord(const size_t &idx) { return x.at(idx); }
0045 KDNode::operator bool() { return (!x.empty()); }
0046 KDNode::operator point_t() { return x; }
0047 KDNode::operator size_t() { return index; }
0048 KDNode::operator pointIndex() { return pointIndex(x, index); }
0049 
0050 KDNodePtr NewKDNodePtr() {
0051     KDNodePtr mynode = std::make_shared< KDNode >();
0052     return mynode;
0053 }
0054 
0055 inline double dist2(const point_t &a, const point_t &b) {
0056     double distc = 0;
0057     for (size_t i = 0; i < a.size(); i++) {
0058         double di = a.at(i) - b.at(i);
0059         distc += di * di;
0060     }
0061     return distc;
0062 }
0063 
0064 inline double dist2(const KDNodePtr &a, const KDNodePtr &b) {
0065     return dist2(a->x, b->x);
0066 }
0067 
0068 inline double dist(const point_t &a, const point_t &b) {
0069     return std::sqrt(dist2(a, b));
0070 }
0071 
0072 inline double dist(const KDNodePtr &a, const KDNodePtr &b) {
0073     return std::sqrt(dist2(a, b));
0074 }
0075 
0076 comparer::comparer(size_t idx_) : idx{idx_} {};
0077 
0078 inline bool comparer::compare_idx(const pointIndex &a,  //
0079                                   const pointIndex &b   //
0080 ) {
0081     return (a.first.at(idx) < b.first.at(idx));  //
0082 }
0083 
0084 inline void sort_on_idx(const pointIndexArr::iterator &begin,  //
0085                         const pointIndexArr::iterator &end,    //
0086                         size_t idx) {
0087     comparer comp(idx);
0088     comp.idx = idx;
0089 
0090     using std::placeholders::_1;
0091     using std::placeholders::_2;
0092 
0093     std::nth_element(begin, begin + std::distance(begin, end) / 2,
0094                      end, std::bind(&comparer::compare_idx, comp, _1, _2));
0095 }
0096 
// A list of points: the KDTree constructor's input and the return type of
// neighborhood_points(). NOTE(review): presumably kdtree.hpp declares the same
// alias (KDTree's constructor signature needs it); TODO confirm and drop one.
using pointVec = std::vector< point_t >;
0098 
0099 KDNodePtr KDTree::make_tree(const pointIndexArr::iterator &begin,  //
0100                             const pointIndexArr::iterator &end,    //
0101                             const size_t &length,                  //
0102                             const size_t &level                    //
0103 ) {
0104     if (begin == end) {
0105         return NewKDNodePtr();  // empty tree
0106     }
0107 
0108     size_t dim = begin->first.size();
0109 
0110     if (length > 1) {
0111         sort_on_idx(begin, end, level);
0112     }
0113 
0114     auto middle = begin + (length / 2);
0115 
0116     auto l_begin = begin;
0117     auto l_end = middle;
0118     auto r_begin = middle + 1;
0119     auto r_end = end;
0120 
0121     size_t l_len = length / 2;
0122     size_t r_len = length - l_len - 1;
0123 
0124     KDNodePtr left;
0125     if (l_len > 0 && dim > 0) {
0126         left = make_tree(l_begin, l_end, l_len, (level + 1) % dim);
0127     } else {
0128         left = leaf;
0129     }
0130     KDNodePtr right;
0131     if (r_len > 0 && dim > 0) {
0132         right = make_tree(r_begin, r_end, r_len, (level + 1) % dim);
0133     } else {
0134         right = leaf;
0135     }
0136 
0137     // KDNode result = KDNode();
0138     return std::make_shared< KDNode >(*middle, left, right);
0139 }
0140 
0141 KDTree::KDTree(pointVec point_array) {
0142     leaf = std::make_shared< KDNode >();
0143     // iterators
0144     pointIndexArr arr;
0145     for (size_t i = 0; i < point_array.size(); i++) {
0146         arr.push_back(pointIndex(point_array.at(i), i));
0147     }
0148 
0149     auto begin = arr.begin();
0150     auto end = arr.end();
0151 
0152     size_t length = arr.size();
0153     size_t level = 0;  // starting
0154 
0155     root = KDTree::make_tree(begin, end, length, level);
0156 
0157     m_empty = point_array.size();
0158 }
0159 
0160 KDNodePtr KDTree::nearest_(   //
0161     const KDNodePtr &branch,  //
0162     const point_t &pt,        //
0163     const size_t &level,      //
0164     const KDNodePtr &best,    //
0165     const double &best_dist   //
0166 ) {
0167     double d, dx, dx2;
0168 
0169     if (!bool(*branch)) {
0170         return NewKDNodePtr();  // basically, null
0171     }
0172 
0173     point_t branch_pt(*branch);
0174     size_t dim = branch_pt.size();
0175 
0176     d = dist2(branch_pt, pt);
0177     dx = branch_pt.at(level) - pt.at(level);
0178     dx2 = dx * dx;
0179 
0180     KDNodePtr best_l = best;
0181     double best_dist_l = best_dist;
0182 
0183     if (d < best_dist) {
0184         best_dist_l = d;
0185         best_l = branch;
0186     }
0187 
0188     size_t next_lv = (level + 1) % dim;
0189     KDNodePtr section;
0190     KDNodePtr other;
0191 
0192     // select which branch makes sense to check
0193     if (dx > 0) {
0194         section = branch->left;
0195         other = branch->right;
0196     } else {
0197         section = branch->right;
0198         other = branch->left;
0199     }
0200 
0201     // keep nearest neighbor from further down the tree
0202     KDNodePtr further = nearest_(section, pt, next_lv, best_l, best_dist_l);
0203     if (!further->x.empty()) {
0204         double dl = dist2(further->x, pt);
0205         if (dl < best_dist_l) {
0206             best_dist_l = dl;
0207             best_l = further;
0208         }
0209     }
0210     // only check the other branch if it makes sense to do so
0211     if (dx2 < best_dist_l) {
0212         further = nearest_(other, pt, next_lv, best_l, best_dist_l);
0213         if (!further->x.empty()) {
0214             double dl = dist2(further->x, pt);
0215             if (dl < best_dist_l) {
0216                 best_dist_l = dl;
0217                 best_l = further;
0218             }
0219         }
0220     }
0221 
0222     return best_l;
0223 };
0224 
0225 // default caller
0226 KDNodePtr KDTree::nearest_(const point_t &pt) {
0227     size_t level = 0;
0228     // KDNodePtr best = branch;
0229     double branch_dist = dist2(point_t(*root), pt);
0230     return nearest_(root,          // beginning of tree
0231                     pt,            // point we are querying
0232                     level,         // start from level 0
0233                     root,          // best is the root
0234                     branch_dist);  // best_dist = branch_dist
0235 };
0236 
0237 point_t KDTree::nearest_point(const point_t &pt) {
0238     return point_t(*nearest_(pt));
0239 };
0240 size_t KDTree::nearest_index(const point_t &pt) {
0241     return size_t(*nearest_(pt));
0242 };
0243 
0244 pointIndex KDTree::nearest_pointIndex(const point_t &pt) {
0245     KDNodePtr Nearest = nearest_(pt);
0246     return pointIndex(point_t(*Nearest), size_t(*Nearest));
0247 }
0248 
0249 bool KDTree::empty()
0250 {
0251     return m_empty;
0252 }
0253 
0254 pointIndexArr KDTree::neighborhood_(  //
0255     const KDNodePtr &branch,          //
0256     const point_t &pt,                //
0257     const double &rad,                //
0258     const size_t &level               //
0259 ) {
0260     double d, dx, dx2;
0261 
0262     if (!bool(*branch)) {
0263         // branch has no point, means it is a leaf,
0264         // no points to add
0265         return pointIndexArr();
0266     }
0267 
0268     size_t dim = pt.size();
0269 
0270     double r2 = rad * rad;
0271 
0272     d = dist2(point_t(*branch), pt);
0273     dx = point_t(*branch).at(level) - pt.at(level);
0274     dx2 = dx * dx;
0275 
0276     pointIndexArr nbh, nbh_s, nbh_o;
0277     if (d <= r2) {
0278         nbh.push_back(pointIndex(*branch));
0279     }
0280 
0281     //
0282     KDNodePtr section;
0283     KDNodePtr other;
0284     if (dx > 0) {
0285         section = branch->left;
0286         other = branch->right;
0287     } else {
0288         section = branch->right;
0289         other = branch->left;
0290     }
0291 
0292     nbh_s = neighborhood_(section, pt, rad, (level + 1) % dim);
0293     nbh.insert(nbh.end(), nbh_s.begin(), nbh_s.end());
0294     if (dx2 < r2) {
0295         nbh_o = neighborhood_(other, pt, rad, (level + 1) % dim);
0296         nbh.insert(nbh.end(), nbh_o.begin(), nbh_o.end());
0297     }
0298 
0299     return nbh;
0300 };
0301 
0302 pointIndexArr KDTree::neighborhood(  //
0303     const point_t &pt,               //
0304     const double &rad) {
0305     size_t level = 0;
0306     return neighborhood_(root, pt, rad, level);
0307 }
0308 
0309 pointVec KDTree::neighborhood_points(  //
0310     const point_t &pt,                 //
0311     const double &rad) {
0312     size_t level = 0;
0313     pointIndexArr nbh = neighborhood_(root, pt, rad, level);
0314     pointVec nbhp;
0315     nbhp.resize(nbh.size());
0316     std::transform(nbh.begin(), nbh.end(), nbhp.begin(),
0317                    [](pointIndex x) { return x.first; });
0318     return nbhp;
0319 }
0320 
0321 indexArr KDTree::neighborhood_indices(  //
0322     const point_t &pt,                  //
0323     const double &rad) {
0324     size_t level = 0;
0325     pointIndexArr nbh = neighborhood_(root, pt, rad, level);
0326     indexArr nbhi;
0327     nbhi.resize(nbh.size());
0328     std::transform(nbh.begin(), nbh.end(), nbhi.begin(),
0329                    [](pointIndex x) { return x.second; });
0330     return nbhi;
0331 }