File indexing completed on 2025-03-09 03:55:01
0001 /* ============================================================ 0002 * 0003 * This file is a part of digiKam 0004 * 0005 * Date : 2019-06-08 0006 * Description : Node of KD-Tree for vector space partitioning 0007 * 0008 * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com> 0009 * 0010 * SPDX-License-Identifier: GPL-2.0-or-later 0011 * 0012 * ============================================================ */ 0013 0014 #include "kd_node.h" 0015 0016 // C++ include 0017 0018 #include <cfloat> 0019 #include <cstdio> 0020 #include <iterator> 0021 0022 // Qt include 0023 0024 #include <QtMath> 0025 #include <QDebug> 0026 0027 // Local includes 0028 0029 #include "dnnfaceextractor.h" 0030 0031 namespace Digikam 0032 { 0033 0034 float KDNode::sqrDistance(const float* const pos1, const float* const pos2, int dimension) 0035 { 0036 if (!pos1 || !pos2) 0037 { 0038 return 0.0F; 0039 } 0040 0041 double sqrDistance = 0.0; 0042 0043 for (int i = 0 ; i < dimension ; ++i) 0044 { 0045 sqrDistance += pow((pos1[i] - pos2[i]), 2); 0046 } 0047 0048 return float(sqrDistance); 0049 } 0050 0051 float KDNode::cosDistance(const float* const pos1, const float* const pos2, int dimension) 0052 { 0053 if (!pos1 || !pos2) 0054 { 0055 return 0.0F; 0056 } 0057 0058 double scalarProduct = 0.0; 0059 double normV1 = 0.0; 0060 double normV2 = 0.0; 0061 0062 for (int i = 0 ; i < dimension ; ++i) 0063 { 0064 scalarProduct += pos1[i] * pos2[i]; 0065 normV1 += pow(pos1[i], 2); 0066 normV2 += pow(pos2[i], 2); 0067 } 0068 0069 return float(scalarProduct / (normV1 * normV2)); 0070 } 0071 0072 // ---------------------------------------------------------------------------------------- 0073 0074 class Q_DECL_HIDDEN KDNode::Private 0075 { 0076 public: 0077 0078 Private(const cv::Mat& nodePos, const int identity, int splitAxis, int dimension) 0079 : nodeID (-1), 0080 identity (identity), 0081 splitAxis (splitAxis), 0082 nbDimension (dimension), 0083 position (nodePos.clone()), 0084 maxRange (nodePos.clone()), 0085 minRange (nodePos.clone()), 0086 parent (nullptr), 0087 left (nullptr), 0088 right (nullptr) 0089 { 0090 } 0091 0092 ~Private() 0093 { 0094 delete left; 0095 delete right; 0096 } 0097 0098 public: 0099 0100 int nodeID; 0101 int identity; 0102 int splitAxis; 0103 int nbDimension; 0104 0105 cv::Mat position; 0106 cv::Mat maxRange; 0107 cv::Mat minRange; 0108 KDNode* parent; 0109 KDNode* left; 0110 KDNode* right; 0111 }; 0112 0113 KDNode::KDNode(const cv::Mat& nodePos, 0114 const int identity, 0115 int splitAxis, 0116 int dimension) 0117 : d(new Private(nodePos, identity, splitAxis, dimension)) 0118 { 0119 Q_ASSERT(splitAxis < dimension); 0120 Q_ASSERT((nodePos.rows == 1) && 0121 (nodePos.cols == dimension) && 0122 (nodePos.type() == CV_32F)); 0123 } 0124 0125 KDNode::~KDNode() 0126 { 0127 delete d; 0128 } 0129 0130 KDNode* KDNode::insert(const cv::Mat& nodePos, const int identity) 0131 { 0132 if (!((nodePos.rows == 1) && 0133 (nodePos.cols == d->nbDimension) && 0134 (nodePos.type() == CV_32F))) 0135 { 0136 return nullptr; 0137 } 0138 0139 KDNode* const parent = findParent(nodePos); 0140 0141 KDNode* const newNode = new KDNode(nodePos, identity, 0142 ((parent->d->splitAxis + 1) % d->nbDimension), 0143 d->nbDimension); 0144 newNode->d->parent = parent; 0145 /* 0146 qCDebug(DIGIKAM_FACESENGINE_LOG) << "parent embedding" << parent->getPosition() << std::endl; 0147 qCDebug(DIGIKAM_FACESENGINE_LOG) << "node embedding" << nodePos << std::endl; 0148 */ 0149 if (nodePos.at<float>(0, parent->d->splitAxis) >= parent->getPosition().at<float>(0, parent->d->splitAxis)) 0150 { 0151 parent->d->right = newNode; 0152 } 0153 else 0154 { 0155 parent->d->left = newNode; 0156 } 0157 0158 return newNode; 0159 } 0160 0161 cv::Mat KDNode::getPosition() const 0162 { 0163 return d->position; 0164 } 0165 0166 int KDNode::getIdentity() 0167 { 0168 return d->identity; 0169 } 0170 0171 void KDNode::setNodeId(int id) 0172 { 0173 d->nodeID = id; 0174 } 0175 0176 double KDNode::getClosestNeighbors(QMap<double, QVector<int> >& neighborList, 0177 const cv::Mat& position, 0178 float sqRange, 0179 float cosThreshold, 0180 int maxNbNeighbors) const 0181 { 0182 if (!position.ptr<float>()) 0183 { 0184 return sqRange; 0185 } 0186 0187 // add current node to the list 0188 0189 const double sqrDistanceToCurrentNode = sqrDistance(position.ptr<float>(), d->position.ptr<float>(), d->nbDimension); 0190 0191 // NOTE: both Euclidean distance and cosine distance can help to avoid error in similarity prediction 0192 0193 if ((sqrDistanceToCurrentNode < sqRange) && 0194 (cosDistance(position.ptr<float>(), d->position.ptr<float>(), d->nbDimension) > cosThreshold)) 0195 { 0196 neighborList[sqrDistanceToCurrentNode].append(d->identity); 0197 0198 // limit the size of the Map to maxNbNeighbors 0199 0200 int size = 0; 0201 0202 for (QMap<double, QVector<int> >::const_iterator iter = neighborList.cbegin(); 0203 iter != neighborList.cend(); 0204 ++iter) 0205 { 0206 size += iter.value().size(); 0207 } 0208 0209 if (size > maxNbNeighbors) 0210 { 0211 // Eliminate the farthest neighbor 0212 0213 QMap<double, QVector<int> >::iterator farthestNodes = std::prev(neighborList.end(), 1); 0214 0215 if (farthestNodes.value().size() == 1) 0216 { 0217 neighborList.erase(farthestNodes); 0218 } 0219 else 0220 { 0221 farthestNodes.value().pop_back(); 0222 } 0223 0224 // update the searching range 0225 0226 sqRange = neighborList.lastKey(); 0227 } 0228 } 0229 0230 // sub-trees Traversal 0231 // NOTE: DBL_MAX helps avoiding accessing nullptr 0232 0233 double sqrDistanceleftTree = 0.0; 0234 0235 if (d->left == nullptr) 0236 { 0237 sqrDistanceleftTree = DBL_MAX; 0238 } 0239 else 0240 { 0241 const float* const minRange = d->left->d->minRange.ptr<float>(); 0242 const float* const maxRange = d->left->d->maxRange.ptr<float>(); 0243 const float* const pos = position.ptr<float>(); 0244 0245 for (int i = 0 ; i < d->nbDimension ; ++i) 0246 { 0247 sqrDistanceleftTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) + 0248 pow(qMax((pos[i] - maxRange[i]), 0.0f), 2)); 0249 } 0250 } 0251 0252 double sqrDistancerightTree = 0.0; 0253 0254 if (d->right == nullptr) 0255 { 0256 sqrDistancerightTree = DBL_MAX; 0257 } 0258 else 0259 { 0260 const float* const minRange = d->right->d->minRange.ptr<float>(); 0261 const float* const maxRange = d->right->d->maxRange.ptr<float>(); 0262 const float* const pos = position.ptr<float>(); 0263 0264 for (int i = 0 ; i < d->nbDimension ; ++i) 0265 { 0266 sqrDistancerightTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) + 0267 pow(qMax((pos[i] - maxRange[i]), 0.0f), 2)); 0268 } 0269 } 0270 0271 // traverse the closest area 0272 0273 if (sqrDistanceleftTree < sqrDistancerightTree) 0274 { 0275 if (sqrDistanceleftTree < sqRange) 0276 { 0277 // traverse left Tree 0278 0279 if (d->left) 0280 { 0281 sqRange = d->left->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0282 0283 if (sqrDistancerightTree < sqRange) 0284 { 0285 // traverse right Tree 0286 0287 if (d->right) 0288 { 0289 sqRange = d->right->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0290 } 0291 } 0292 } 0293 } 0294 } 0295 else 0296 { 0297 if (sqrDistancerightTree < sqRange) 0298 { 0299 // traverse right Tree 0300 0301 if (d->right) 0302 { 0303 sqRange = d->right->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0304 0305 if (sqrDistanceleftTree < sqRange) 0306 { 0307 // traverse left Tree 0308 0309 if (d->left) 0310 { 0311 sqRange = d->left->getClosestNeighbors(neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0312 } 0313 } 0314 } 0315 } 0316 } 0317 0318 return sqRange; 0319 } 0320 0321 void KDNode::updateRange(const cv::Mat& pos) 0322 { 0323 if (!((pos.rows == 1) && 0324 (pos.cols == d->nbDimension) && 0325 (pos.type() == CV_32F))) 0326 { 0327 return; 0328 } 0329 0330 float* minRange = d->minRange.ptr<float>(); 0331 float* maxRange = d->maxRange.ptr<float>(); 0332 const float* position = pos.ptr<float>(); 0333 0334 for (int i = 0 ; i < d->nbDimension ; ++i) 0335 { 0336 maxRange[i] = std::max(maxRange[i], position[i]); 0337 minRange[i] = std::min(minRange[i], position[i]); 0338 } 0339 } 0340 0341 KDNode* KDNode::findParent(const cv::Mat& nodePos) 0342 { 0343 KDNode* parent = nullptr; 0344 KDNode* currentNode = this; 0345 0346 while (currentNode != nullptr) 0347 { 0348 currentNode->updateRange(nodePos); 0349 0350 int split = currentNode->d->splitAxis; 0351 parent = currentNode; 0352 0353 if (nodePos.at<float>(0, split) >= currentNode->d->position.at<float>(0, split)) 0354 { 0355 currentNode = currentNode->d->right; 0356 } 0357 else 0358 { 0359 currentNode = currentNode->d->left; 0360 } 0361 } 0362 0363 return parent; 0364 } 0365 0366 } // namespace Digikam