File indexing completed on 2025-03-09 03:55:01

0001 /* ============================================================
0002  *
0003  * This file is a part of digiKam
0004  *
0005  * Date        : 2019-06-08
0006  * Description : Node of KD-Tree for vector space partitioning
0007  *
0008  * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com>
0009  *
0010  * SPDX-License-Identifier: GPL-2.0-or-later
0011  *
0012  * ============================================================ */
0013 
0014 #include "kd_node.h"
0015 
0016 // C++ include
0017 
0018 #include <cfloat>
0019 #include <cstdio>
0020 #include <iterator>
0021 
0022 // Qt include
0023 
0024 #include <QtMath>
0025 #include <QDebug>
0026 
0027 // Local includes
0028 
0029 #include "dnnfaceextractor.h"
0030 
0031 namespace Digikam
0032 {
0033 
0034 float KDNode::sqrDistance(const float* const pos1, const float* const pos2, int dimension)
0035 {
0036     if (!pos1 || !pos2)
0037     {
0038         return 0.0F;
0039     }
0040 
0041     double sqrDistance = 0.0;
0042 
0043     for (int i = 0 ; i < dimension ; ++i)
0044     {
0045         sqrDistance += pow((pos1[i] - pos2[i]), 2);
0046     }
0047 
0048     return float(sqrDistance);
0049 }
0050 
0051 float KDNode::cosDistance(const float* const pos1, const float* const pos2, int dimension)
0052 {
0053     if (!pos1 || !pos2)
0054     {
0055         return 0.0F;
0056     }
0057 
0058     double scalarProduct = 0.0;
0059     double normV1        = 0.0;
0060     double normV2        = 0.0;
0061 
0062     for (int i = 0 ; i < dimension ; ++i)
0063     {
0064         scalarProduct += pos1[i] * pos2[i];
0065         normV1        += pow(pos1[i], 2);
0066         normV2        += pow(pos2[i], 2);
0067     }
0068 
0069     return float(scalarProduct / (normV1 * normV2));
0070 }
0071 
0072 // ----------------------------------------------------------------------------------------
0073 
0074 class Q_DECL_HIDDEN KDNode::Private
0075 {
0076 public:
0077 
0078     Private(const cv::Mat& nodePos, const int identity, int splitAxis, int dimension)
0079         : nodeID        (-1),
0080           identity      (identity),
0081           splitAxis     (splitAxis),
0082           nbDimension   (dimension),
0083           position      (nodePos.clone()),
0084           maxRange      (nodePos.clone()),
0085           minRange      (nodePos.clone()),
0086           parent        (nullptr),
0087           left          (nullptr),
0088           right         (nullptr)
0089     {
0090     }
0091 
0092     ~Private()
0093     {
0094         delete left;
0095         delete right;
0096     }
0097 
0098 public:
0099 
0100     int     nodeID;
0101     int     identity;
0102     int     splitAxis;
0103     int     nbDimension;
0104 
0105     cv::Mat position;
0106     cv::Mat maxRange;
0107     cv::Mat minRange;
0108     KDNode* parent;
0109     KDNode* left;
0110     KDNode* right;
0111 };
0112 
0113 KDNode::KDNode(const cv::Mat& nodePos,
0114                const int      identity,
0115                int            splitAxis,
0116                int            dimension)
0117     : d(new Private(nodePos, identity, splitAxis, dimension))
0118 {
0119     Q_ASSERT(splitAxis < dimension);
0120     Q_ASSERT((nodePos.rows   == 1)         &&
0121              (nodePos.cols   == dimension) &&
0122              (nodePos.type() == CV_32F));
0123 }
0124 
/**
 * Destroys the node. Deleting the private data also deletes the left and right children
 * (see KDNode::Private::~Private), so destroying a node frees its whole sub-tree.
 */
KDNode::~KDNode()
{
    delete d;
}
0129 
0130 KDNode* KDNode::insert(const cv::Mat& nodePos, const int identity)
0131 {
0132     if (!((nodePos.rows   == 1)              &&
0133           (nodePos.cols   == d->nbDimension) &&
0134           (nodePos.type() == CV_32F)))
0135     {
0136         return nullptr;
0137     }
0138 
0139     KDNode* const parent  = findParent(nodePos);
0140 
0141     KDNode* const newNode = new KDNode(nodePos, identity,
0142                                        ((parent->d->splitAxis + 1) % d->nbDimension),
0143                                        d->nbDimension);
0144     newNode->d->parent    = parent;
0145 /*
0146     qCDebug(DIGIKAM_FACESENGINE_LOG) << "parent embedding" << parent->getPosition() << std::endl;
0147     qCDebug(DIGIKAM_FACESENGINE_LOG) << "node embedding" << nodePos << std::endl;
0148 */
0149     if (nodePos.at<float>(0, parent->d->splitAxis) >= parent->getPosition().at<float>(0, parent->d->splitAxis))
0150     {
0151         parent->d->right = newNode;
0152     }
0153     else
0154     {
0155         parent->d->left = newNode;
0156     }
0157 
0158     return newNode;
0159 }
0160 
/**
 * Returns the position this node was built from.
 * NOTE: copying a cv::Mat shares its data buffer (OpenCV header semantics), so writing into
 * the returned matrix would presumably modify the node's stored position; treat it as read-only.
 */
cv::Mat KDNode::getPosition() const
{
    return d->position;
}
0165 
/**
 * Returns the identity label given to this node at construction time
 * (the value reported in neighbor lists by getClosestNeighbors()).
 */
int KDNode::getIdentity()
{
    return d->identity;
}
0170 
/**
 * Stores an id on this node (initially -1).
 * NOTE(review): the code visible here never reads it back; presumably the owning tree
 * uses it for bookkeeping — verify against callers.
 */
void KDNode::setNodeId(int id)
{
    d->nodeID = id;
}
0175 
0176 double KDNode::getClosestNeighbors(QMap<double, QVector<int> >& neighborList,
0177                                    const cv::Mat&               position,
0178                                    float                        sqRange,
0179                                    float                        cosThreshold,
0180                                    int                          maxNbNeighbors) const
0181 {
0182     if (!position.ptr<float>())
0183     {
0184         return sqRange;
0185     }
0186 
0187     // add current node to the list
0188 
0189     const double sqrDistanceToCurrentNode = sqrDistance(position.ptr<float>(), d->position.ptr<float>(), d->nbDimension);
0190 
0191     // NOTE: both Euclidean distance and cosine distance can help to avoid error in similarity prediction
0192 
0193     if ((sqrDistanceToCurrentNode < sqRange) &&
0194         (cosDistance(position.ptr<float>(), d->position.ptr<float>(), d->nbDimension) > cosThreshold))
0195     {
0196         neighborList[sqrDistanceToCurrentNode].append(d->identity);
0197 
0198         // limit the size of the Map to maxNbNeighbors
0199 
0200         int size = 0;
0201 
0202         for (QMap<double, QVector<int> >::const_iterator iter  = neighborList.cbegin();
0203                                                          iter != neighborList.cend();
0204                                                          ++iter)
0205         {
0206             size += iter.value().size();
0207         }
0208 
0209         if (size > maxNbNeighbors)
0210         {
0211             // Eliminate the farthest neighbor
0212 
0213             QMap<double, QVector<int> >::iterator farthestNodes = std::prev(neighborList.end(), 1);
0214 
0215             if (farthestNodes.value().size() == 1)
0216             {
0217                 neighborList.erase(farthestNodes);
0218             }
0219             else
0220             {
0221                 farthestNodes.value().pop_back();
0222             }
0223 
0224             // update the searching range
0225 
0226             sqRange = neighborList.lastKey();
0227         }
0228     }
0229 
0230     // sub-trees Traversal
0231     // NOTE: DBL_MAX helps avoiding accessing nullptr
0232 
0233     double sqrDistanceleftTree  = 0.0;
0234 
0235     if (d->left == nullptr)
0236     {
0237         sqrDistanceleftTree = DBL_MAX;
0238     }
0239     else
0240     {
0241         const float* const minRange = d->left->d->minRange.ptr<float>();
0242         const float* const maxRange = d->left->d->maxRange.ptr<float>();
0243         const float* const pos      = position.ptr<float>();
0244 
0245         for (int i = 0 ; i < d->nbDimension ; ++i)
0246         {
0247             sqrDistanceleftTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) +
0248                                     pow(qMax((pos[i] - maxRange[i]), 0.0f), 2));
0249         }
0250     }
0251 
0252     double sqrDistancerightTree = 0.0;
0253 
0254     if (d->right == nullptr)
0255     {
0256         sqrDistancerightTree = DBL_MAX;
0257     }
0258     else
0259     {
0260         const float* const minRange = d->right->d->minRange.ptr<float>();
0261         const float* const maxRange = d->right->d->maxRange.ptr<float>();
0262         const float* const pos      = position.ptr<float>();
0263 
0264         for (int i = 0 ; i < d->nbDimension ; ++i)
0265         {
0266             sqrDistancerightTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) +
0267                                      pow(qMax((pos[i] - maxRange[i]), 0.0f), 2));
0268         }
0269     }
0270 
0271     // traverse the closest area
0272 
0273     if (sqrDistanceleftTree < sqrDistancerightTree)
0274     {
0275         if (sqrDistanceleftTree < sqRange)
0276         {
0277             // traverse left Tree
0278 
0279             if (d->left)
0280             {
0281                 sqRange = d->left->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0282 
0283                 if (sqrDistancerightTree < sqRange)
0284                 {
0285                     // traverse right Tree
0286 
0287                     if (d->right)
0288                     {
0289                         sqRange = d->right->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0290                     }
0291                 }
0292             }
0293         }
0294     }
0295     else
0296     {
0297         if (sqrDistancerightTree < sqRange)
0298         {
0299             // traverse right Tree
0300 
0301             if (d->right)
0302             {
0303                 sqRange = d->right->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0304 
0305                 if (sqrDistanceleftTree < sqRange)
0306                 {
0307                     // traverse left Tree
0308 
0309                     if (d->left)
0310                     {
0311                         sqRange = d->left->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0312                     }
0313                 }
0314             }
0315         }
0316     }
0317 
0318     return sqRange;
0319 }
0320 
0321 void KDNode::updateRange(const cv::Mat& pos)
0322 {
0323     if (!((pos.rows   == 1)              &&
0324           (pos.cols   == d->nbDimension) &&
0325           (pos.type() == CV_32F)))
0326     {
0327         return;
0328     }
0329 
0330     float* minRange       = d->minRange.ptr<float>();
0331     float* maxRange       = d->maxRange.ptr<float>();
0332     const float* position = pos.ptr<float>();
0333 
0334     for (int i = 0 ; i < d->nbDimension ; ++i)
0335     {
0336         maxRange[i] = std::max(maxRange[i], position[i]);
0337         minRange[i] = std::min(minRange[i], position[i]);
0338     }
0339 }
0340 
0341 KDNode* KDNode::findParent(const cv::Mat& nodePos)
0342 {
0343     KDNode* parent      = nullptr;
0344     KDNode* currentNode = this;
0345 
0346     while (currentNode != nullptr)
0347     {
0348         currentNode->updateRange(nodePos);
0349 
0350         int split       = currentNode->d->splitAxis;
0351         parent          = currentNode;
0352 
0353         if (nodePos.at<float>(0, split) >= currentNode->d->position.at<float>(0, split))
0354         {
0355             currentNode = currentNode->d->right;
0356         }
0357         else
0358         {
0359             currentNode = currentNode->d->left;
0360         }
0361     }
0362 
0363     return parent;
0364 }
0365 
0366 } // namespace Digikam