File indexing completed on 2025-01-26 04:25:02

0001 #pragma once
0002 
0003 /*
0004  * file: KDTree.hpp
0005  * author: J. Frederico Carvalho
0006  *
 * This is an adaptation of the k-d tree implementation on Rosetta Code
 *  https://rosettacode.org/wiki/K-d_tree
 * It reimplements the original C code in C++ and adds a few more
 * queries than the original provides.
0011  *
0012  */
0013 
0014 #include <algorithm>
0015 #include <functional>
0016 #include <memory>
0017 #include <vector>
0018 
// Basic value types shared by the k-d tree API.
using point_t = std::vector< double >;            // coordinates of one point
using indexArr = std::vector< size_t >;           // a list of point indices
using pointIndex = std::pair< point_t, size_t >;  // a point paired with its index
0022 
0023 class KDNode {
0024    public:
0025     using KDNodePtr = std::shared_ptr< KDNode >;
0026     size_t index;
0027     point_t x;
0028     KDNodePtr left;
0029     KDNodePtr right;
0030 
0031     // initializer
0032     KDNode();
0033     KDNode(const point_t &, const size_t &, const KDNodePtr &,
0034            const KDNodePtr &);
0035     KDNode(const pointIndex &, const KDNodePtr &, const KDNodePtr &);
0036     ~KDNode();
0037 
0038     // getter
0039     double coord(const size_t &);
0040 
0041     // conversions
0042     explicit operator bool();
0043     explicit operator point_t();
0044     explicit operator size_t();
0045     explicit operator pointIndex();
0046 };
0047 
0048 using KDNodePtr = std::shared_ptr< KDNode >;
0049 
0050 KDNodePtr NewKDNodePtr();
0051 
0052 // square euclidean distance
0053 inline double dist2(const point_t &, const point_t &);
0054 inline double dist2(const KDNodePtr &, const KDNodePtr &);
0055 
0056 // euclidean distance
0057 inline double dist(const point_t &, const point_t &);
0058 inline double dist(const KDNodePtr &, const KDNodePtr &);
0059 
0060 // Need for sorting
/// Ordering helper used when sorting (point, index) pairs.
/// Stores the coordinate position `idx` that compare_idx presumably orders by.
class comparer {
   public:
    size_t idx;  // coordinate position used for comparisons

    explicit comparer(size_t axis);

    // Ordering predicate over two (point, index) pairs.
    // NOTE(review): declared `inline` but defined outside this header, so only
    // the translation unit holding the definition can safely call it.
    inline bool compare_idx(const std::pair< std::vector< double >, size_t > &lhs,
                            const std::pair< std::vector< double >, size_t > &rhs);
};
0070 
0071 using pointIndexArr = typename std::vector< pointIndex >;
0072 
0073 inline void sort_on_idx(const pointIndexArr::iterator &,  //
0074                         const pointIndexArr::iterator &,  //
0075                         size_t idx);
0076 
0077 using pointVec = std::vector< point_t >;
0078 
0079 class KDTree {
0080     KDNodePtr root;
0081     KDNodePtr leaf;
0082 
0083     KDNodePtr make_tree(const pointIndexArr::iterator &begin,  //
0084                         const pointIndexArr::iterator &end,    //
0085                         const size_t &length,                  //
0086                         const size_t &level                    //
0087     );
0088 
0089    public:
0090     KDTree() = default;
0091     explicit KDTree(pointVec point_array);
0092 
0093    private:
0094     KDNodePtr nearest_(           //
0095         const KDNodePtr &branch,  //
0096         const point_t &pt,        //
0097         const size_t &level,      //
0098         const KDNodePtr &best,    //
0099         const double &best_dist   //
0100     );
0101 
0102     bool m_empty {true};
0103     // default caller
0104     KDNodePtr nearest_(const point_t &pt);
0105 
0106    public:
0107     point_t nearest_point(const point_t &pt);
0108     size_t nearest_index(const point_t &pt);
0109     pointIndex nearest_pointIndex(const point_t &pt);
0110     bool empty();
0111 
0112    private:
0113     pointIndexArr neighborhood_(  //
0114         const KDNodePtr &branch,  //
0115         const point_t &pt,        //
0116         const double &rad,        //
0117         const size_t &level       //
0118     );
0119 
0120    public:
0121     pointIndexArr neighborhood(  //
0122         const point_t &pt,       //
0123         const double &rad);
0124 
0125     pointVec neighborhood_points(  //
0126         const point_t &pt,         //
0127         const double &rad);
0128 
0129     indexArr neighborhood_indices(  //
0130         const point_t &pt,          //
0131         const double &rad);
0132 };