File indexing completed on 2025-03-09 03:54:59
0001 /* ============================================================ 0002 * 0003 * This file is a part of digiKam 0004 * 0005 * Date : 02-02-2012 0006 * Description : Face database interface for spatial storage of face embedding. 0007 * 0008 * SPDX-FileCopyrightText: 2012-2013 by Marcel Wiesweg <marcel dot wiesweg at gmx dot de> 0009 * SPDX-FileCopyrightText: 2010-2024 by Gilles Caulier <caulier dot gilles at gmail dot com> 0010 * SPDX-FileCopyrightText: 2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com> 0011 * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com> 0012 * 0013 * SPDX-License-Identifier: GPL-2.0-or-later 0014 * 0015 * ============================================================ */ 0016 0017 #include "facedb_p.h" 0018 0019 namespace Digikam 0020 { 0021 0022 class FaceDb::DataNode 0023 { 0024 public: 0025 0026 explicit DataNode() 0027 : nodeID (0), 0028 label (0), 0029 splitAxis (0), 0030 left (-1), 0031 right (-1) 0032 { 0033 } 0034 0035 bool isNull() const 0036 { 0037 return (nodeID == 0); 0038 } 0039 0040 public: 0041 0042 int nodeID; 0043 int label; 0044 int splitAxis; 0045 int left; 0046 int right; 0047 cv::Mat position; 0048 cv::Mat minRange; 0049 cv::Mat maxRange; 0050 }; 0051 0052 bool FaceDb::insertToTreeDb(const int nodeID, const cv::Mat& faceEmbedding) const 0053 { 0054 bool isLeftChild = false; 0055 int parentSplitAxis = 0; 0056 int parentID = findParentTreeDb(faceEmbedding, isLeftChild, parentSplitAxis); 0057 0058 if (parentID < 0) 0059 { 0060 qCWarning(DIGIKAM_FACEDB_LOG) << "fail to find parent node"; 0061 return false; 0062 } 0063 0064 QVariantList bindingValues; 0065 0066 bindingValues << (parentSplitAxis + 1) % 128; 0067 bindingValues << nodeID; 0068 bindingValues << QByteArray::fromRawData((char*)faceEmbedding.ptr<float>(), (sizeof(float) * 128)); 0069 bindingValues << QByteArray::fromRawData((char*)faceEmbedding.ptr<float>(), (sizeof(float) * 128)); 0070 bindingValues << parentID; 0071 0072 // insert node to database 0073 0074 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("INSERT INTO KDTree " 0075 "(split_axis, position, max_range, min_range, parent, `left`, `right`) " 0076 "VALUES (?, ?, ?, ?, ?, NULL, NULL)"), 0077 bindingValues); 0078 0079 int newNode = query.lastInsertId().toInt(); 0080 0081 if (newNode <= 0) 0082 { 0083 qCWarning(DIGIKAM_FACEDB_LOG) << "error insert into treedb" << query.lastError(); 0084 } 0085 0086 if (parentID > 0) 0087 { 0088 bindingValues.clear(); 0089 bindingValues << newNode; 0090 bindingValues << parentID; 0091 0092 // not root -> update parent 0093 0094 if (isLeftChild) 0095 { 0096 query = d->db->execQuery(QLatin1String("UPDATE KDTree SET left = ? WHERE id = ?;"), bindingValues); 0097 } 0098 else 0099 { 0100 query = d->db->execQuery(QLatin1String("UPDATE KDTree SET right = ? WHERE id = ?;"), bindingValues); 0101 } 0102 } 0103 0104 return true; 0105 } 0106 0107 QMap<double, QVector<int> > FaceDb::getClosestNeighborsTreeDb(const cv::Mat& position, 0108 float sqRange, 0109 float cosThreshold, 0110 int maxNbNeighbors) const 0111 { 0112 QMap<double, QVector<int> > closestNeighbors; 0113 0114 DataNode root; 0115 0116 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` " 0117 "FROM KDTree WHERE id = 1")); 0118 if (query.next()) 0119 { 0120 // encapsulate data node 0121 0122 root.nodeID = 1; 0123 int embeddingID = query.value(0).toInt(); 0124 root.maxRange = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0125 root.minRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone(); 0126 root.left = query.value(3).toInt(); 0127 root.right = query.value(4).toInt(); 0128 0129 QVariantList bindingValues; 0130 bindingValues << embeddingID; 0131 0132 query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"), 0133 bindingValues); 0134 0135 if (query.next()) 0136 { 0137 root.label = query.value(0).toInt(); 0138 root.position = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0139 } 0140 0141 getClosestNeighborsTreeDb(root, closestNeighbors, position, sqRange, cosThreshold, maxNbNeighbors); 0142 } 0143 0144 return closestNeighbors; 0145 } 0146 0147 void FaceDb::clearTreeDb() const 0148 { 0149 d->db->execSql(QLatin1String("DELETE FROM KDTree;")); 0150 } 0151 0152 void FaceDb::updateRangeTreeDb(int nodeId, cv::Mat& minRange, cv::Mat& maxRange, const cv::Mat& position) const 0153 { 0154 float* const min = minRange.ptr<float>(); 0155 float* const max = maxRange.ptr<float>(); 0156 const float* pos = position.ptr<float>(); 0157 0158 for (int i = 0 ; i < position.cols ; ++i) 0159 { 0160 max[i] = std::max(max[i], pos[i]); 0161 min[i] = std::min(min[i], pos[i]); 0162 } 0163 0164 QVariantList bindingValues; 0165 0166 bindingValues << QByteArray::fromRawData((char*)max, (sizeof(float) * 128)); 0167 bindingValues << QByteArray::fromRawData((char*)min, (sizeof(float) * 128)); 0168 bindingValues << nodeId; 0169 0170 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("UPDATE KDTree SET max_range = ?, min_range = ? WHERE id = ?;"), 0171 bindingValues); 0172 } 0173 0174 int FaceDb::findParentTreeDb(const cv::Mat& nodePos, bool& leftChild, int& parentSplitAxis) const 0175 { 0176 int parent = 1; 0177 QVariant currentNode = parent; 0178 0179 while (currentNode.isValid() && !currentNode.isNull()) 0180 { 0181 parent = currentNode.toInt(); 0182 0183 QVariantList bindingValues; 0184 bindingValues << parent; 0185 0186 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT split_axis, position, max_range, min_range, `left`, `right` " 0187 "FROM KDTree WHERE id = ?"), 0188 bindingValues); 0189 0190 if (! query.next()) 0191 { 0192 if (parent == 1) 0193 { 0194 /* 0195 qCDebug(DIGIKAM_FACEDB_LOG) << "add root"; 0196 */ 0197 return 0; 0198 } 0199 0200 qCWarning(DIGIKAM_FACEDB_LOG) << "Error query parent =" << parent << query.lastError(); 0201 return -1; 0202 } 0203 0204 /* 0205 qCDebug(DIGIKAM_FACEDB_LOG) << "split axis" << query.value(0).toInt() 0206 << "left" << query.value(4) 0207 << "right" << query.value(5); 0208 */ 0209 0210 int split = query.value(0).toInt(); 0211 parentSplitAxis = split; 0212 0213 int embeddingId = query.value(1).toInt(); 0214 cv::Mat maxRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone(); 0215 cv::Mat minRange = cv::Mat(1, 128, CV_32F, query.value(3).toByteArray().data()).clone(); 0216 QVariant left = query.value(4); 0217 QVariant right = query.value(5); 0218 0219 bindingValues.clear(); 0220 bindingValues << embeddingId; 0221 0222 query = d->db->execQuery(QLatin1String("SELECT embedding FROM FaceMatrices WHERE id = ?"), 0223 bindingValues); 0224 0225 if (! query.next()) 0226 { 0227 qCWarning(DIGIKAM_FACEDB_LOG) << "fail to find parent face embedding" << query.lastError(); 0228 return -1; 0229 } 0230 0231 cv::Mat position = cv::Mat(1, 128, CV_32F, query.value(0).toByteArray().data()).clone(); 0232 0233 if (nodePos.at<float>(0, split) >= position.at<float>(0, split)) 0234 { 0235 currentNode = right; 0236 leftChild = false; 0237 } 0238 else 0239 { 0240 currentNode = left; 0241 leftChild = true; 0242 } 0243 0244 updateRangeTreeDb(parent, minRange, maxRange, nodePos); 0245 } 0246 0247 return parent; 0248 } 0249 0250 double FaceDb::getClosestNeighborsTreeDb(const DataNode& subTree, 0251 QMap<double, QVector<int> >& neighborList, 0252 const cv::Mat& position, 0253 float sqRange, 0254 float cosThreshold, 0255 int maxNbNeighbors) const 0256 { 0257 if (subTree.isNull()) 0258 { 0259 return sqRange; 0260 } 0261 0262 // try to add current node to the list 0263 0264 const float sqrdistanceToCurrentNode = KDNode::sqrDistance(position.ptr<float>(), subTree.position.ptr<float>(), 128); 0265 0266 if ((sqrdistanceToCurrentNode < sqRange) && 0267 (KDNode::cosDistance(position.ptr<float>(), subTree.position.ptr<float>(), 128) > cosThreshold)) 0268 { 0269 neighborList[sqrdistanceToCurrentNode].append(subTree.label); 0270 0271 // limit the size of the Map to maxNbNeighbors 0272 0273 int size = 0; 0274 0275 for (QMap<double, QVector<int> >::const_iterator iter = neighborList.cbegin(); 0276 iter != neighborList.cend(); 0277 ++iter) 0278 { 0279 size += iter.value().size(); 0280 } 0281 0282 if (size > maxNbNeighbors) 0283 { 0284 // Eliminate the farthest neighbor 0285 0286 QMap<double, QVector<int> >::iterator farthestNodes = std::prev(neighborList.end(), 1); 0287 0288 if (farthestNodes.value().size() == 1) 0289 { 0290 neighborList.erase(farthestNodes); 0291 } 0292 else 0293 { 0294 farthestNodes.value().pop_back(); 0295 } 0296 0297 // update the searching range 0298 0299 sqRange = neighborList.lastKey(); 0300 } 0301 } 0302 0303 // sub-trees Traversal 0304 0305 double sqrDistanceLeftTree = 0.0; 0306 double sqrDistanceRightTree = 0.0; 0307 DataNode leftNode; 0308 DataNode rightNode; 0309 0310 if (subTree.left <= 0) 0311 { 0312 sqrDistanceLeftTree = DBL_MAX; 0313 } 0314 else 0315 { 0316 QVariantList bindingValues; 0317 bindingValues << subTree.left; 0318 0319 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` " 0320 "FROM KDTree WHERE id = ?"), 0321 bindingValues); 0322 0323 if (query.next()) 0324 { 0325 // encapsulate data node 0326 0327 leftNode.nodeID = subTree.left; 0328 int embeddingID = query.value(0).toInt(); 0329 leftNode.maxRange = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0330 leftNode.minRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone(); 0331 leftNode.left = query.value(3).toInt(); 0332 leftNode.right = query.value(4).toInt(); 0333 0334 bindingValues.clear(); 0335 bindingValues << embeddingID; 0336 0337 query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"), 0338 bindingValues); 0339 0340 if (query.next()) 0341 { 0342 leftNode.label = query.value(0).toInt(); 0343 leftNode.position = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0344 0345 const float* minRange = leftNode.minRange.ptr<float>(); 0346 const float* maxRange = leftNode.maxRange.ptr<float>(); 0347 const float* pos = leftNode.position.ptr<float>(); 0348 0349 for (int i = 0 ; i < 128 ; ++i) 0350 { 0351 sqrDistanceLeftTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) + 0352 pow(qMax((pos[i] - maxRange[i]), 0.0f), 2)); 0353 } 0354 } 0355 } 0356 } 0357 0358 if (subTree.right <= 0) 0359 { 0360 sqrDistanceRightTree = DBL_MAX; 0361 } 0362 else 0363 { 0364 QVariantList bindingValues; 0365 bindingValues << subTree.right; 0366 0367 DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` " 0368 "FROM KDTree WHERE id = ?"), 0369 bindingValues); 0370 0371 if (query.next()) 0372 { 0373 // encapsulate data node 0374 0375 rightNode.nodeID = subTree.right; 0376 int embeddingID = query.value(0).toInt(); 0377 rightNode.maxRange = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0378 rightNode.minRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone(); 0379 rightNode.left = query.value(3).toInt(); 0380 rightNode.right = query.value(4).toInt(); 0381 0382 bindingValues.clear(); 0383 bindingValues << embeddingID; 0384 0385 query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"), 0386 bindingValues); 0387 0388 if (query.next()) 0389 { 0390 rightNode.label = query.value(0).toInt(); 0391 rightNode.position = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone(); 0392 0393 const float* const minRange = rightNode.minRange.ptr<float>(); 0394 const float* const maxRange = rightNode.maxRange.ptr<float>(); 0395 const float* const pos = rightNode.position.ptr<float>(); 0396 0397 for (int i = 0 ; i < 128 ; ++i) 0398 { 0399 sqrDistanceRightTree += (pow(qMax((minRange[i] - pos[i]), 0.0f), 2) + 0400 pow(qMax((pos[i] - maxRange[i]), 0.0f), 2)); 0401 } 0402 } 0403 } 0404 } 0405 0406 // traverse the closest area 0407 0408 if (sqrDistanceLeftTree < sqrDistanceRightTree) 0409 { 0410 if (sqrDistanceLeftTree < sqRange) 0411 { 0412 // traverse left Tree 0413 0414 sqRange = getClosestNeighborsTreeDb(leftNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0415 0416 if (sqrDistanceRightTree < sqRange) 0417 { 0418 // traverse right Tree 0419 0420 sqRange = getClosestNeighborsTreeDb(rightNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0421 } 0422 } 0423 } 0424 else 0425 { 0426 if (sqrDistanceRightTree < sqRange) 0427 { 0428 // traverse right Tree 0429 0430 sqRange = getClosestNeighborsTreeDb(rightNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0431 0432 if (sqrDistanceLeftTree < sqRange) 0433 { 0434 // traverse left Tree 0435 0436 sqRange = getClosestNeighborsTreeDb(leftNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors); 0437 } 0438 } 0439 } 0440 0441 return sqRange; 0442 } 0443 0444 } // namespace Digikam