File indexing completed on 2025-03-09 03:54:59

0001 /* ============================================================
0002  *
0003  * This file is a part of digiKam
0004  *
0005  * Date        : 02-02-2012
0006  * Description : Face database interface for spatial storage of face embedding.
0007  *
0008  * SPDX-FileCopyrightText: 2012-2013 by Marcel Wiesweg <marcel dot wiesweg at gmx dot de>
0009  * SPDX-FileCopyrightText: 2010-2024 by Gilles Caulier <caulier dot gilles at gmail dot com>
0010  * SPDX-FileCopyrightText:      2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com>
0011  * SPDX-FileCopyrightText:      2020 by Nghia Duong <minhnghiaduong997 at gmail dot com>
0012  *
0013  * SPDX-License-Identifier: GPL-2.0-or-later
0014  *
0015  * ============================================================ */
0016 
0017 #include "facedb_p.h"
0018 
0019 namespace Digikam
0020 {
0021 
0022 class FaceDb::DataNode
0023 {
0024 public:
0025 
0026     explicit DataNode()
0027         : nodeID    (0),
0028           label     (0),
0029           splitAxis (0),
0030           left      (-1),
0031           right     (-1)
0032     {
0033     }
0034 
0035     bool isNull() const
0036     {
0037         return (nodeID == 0);
0038     }
0039 
0040 public:
0041 
0042     int     nodeID;
0043     int     label;
0044     int     splitAxis;
0045     int     left;
0046     int     right;
0047     cv::Mat position;
0048     cv::Mat minRange;
0049     cv::Mat maxRange;
0050 };
0051 
0052 bool FaceDb::insertToTreeDb(const int nodeID, const cv::Mat& faceEmbedding) const
0053 {
0054     bool isLeftChild    = false;
0055     int parentSplitAxis = 0;
0056     int parentID        = findParentTreeDb(faceEmbedding, isLeftChild, parentSplitAxis);
0057 
0058     if (parentID < 0)
0059     {
0060         qCWarning(DIGIKAM_FACEDB_LOG) << "fail to find parent node";
0061         return false;
0062     }
0063 
0064     QVariantList bindingValues;
0065 
0066     bindingValues << (parentSplitAxis + 1) % 128;
0067     bindingValues << nodeID;
0068     bindingValues << QByteArray::fromRawData((char*)faceEmbedding.ptr<float>(), (sizeof(float) * 128));
0069     bindingValues << QByteArray::fromRawData((char*)faceEmbedding.ptr<float>(), (sizeof(float) * 128));
0070     bindingValues << parentID;
0071 
0072     // insert node to database
0073 
0074     DbEngineSqlQuery query = d->db->execQuery(QLatin1String("INSERT INTO KDTree "
0075                                                             "(split_axis, position, max_range, min_range, parent, `left`, `right`) "
0076                                                             "VALUES (?, ?, ?, ?, ?, NULL, NULL)"),
0077                                               bindingValues);
0078 
0079     int newNode            = query.lastInsertId().toInt();
0080 
0081     if (newNode <= 0)
0082     {
0083         qCWarning(DIGIKAM_FACEDB_LOG) << "error insert into treedb" << query.lastError();
0084     }
0085 
0086     if (parentID > 0)
0087     {
0088         bindingValues.clear();
0089         bindingValues << newNode;
0090         bindingValues << parentID;
0091 
0092         // not root -> update parent
0093 
0094         if (isLeftChild)
0095         {
0096             query = d->db->execQuery(QLatin1String("UPDATE KDTree SET left = ? WHERE id = ?;"), bindingValues);
0097         }
0098         else
0099         {
0100             query = d->db->execQuery(QLatin1String("UPDATE KDTree SET right = ? WHERE id = ?;"), bindingValues);
0101         }
0102     }
0103 
0104     return true;
0105 }
0106 
0107 QMap<double, QVector<int> > FaceDb::getClosestNeighborsTreeDb(const cv::Mat& position,
0108                                                               float sqRange,
0109                                                               float cosThreshold,
0110                                                               int maxNbNeighbors) const
0111 {
0112     QMap<double, QVector<int> > closestNeighbors;
0113 
0114     DataNode root;
0115 
0116     DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` "
0117                                                             "FROM KDTree WHERE id = 1"));
0118     if (query.next())
0119     {
0120         // encapsulate data node
0121 
0122         root.nodeID     = 1;
0123         int embeddingID = query.value(0).toInt();
0124         root.maxRange   = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0125         root.minRange   = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone();
0126         root.left       = query.value(3).toInt();
0127         root.right      = query.value(4).toInt();
0128 
0129         QVariantList bindingValues;
0130         bindingValues << embeddingID;
0131 
0132         query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"),
0133                                  bindingValues);
0134 
0135         if (query.next())
0136         {
0137             root.label    = query.value(0).toInt();
0138             root.position = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0139         }
0140 
0141         getClosestNeighborsTreeDb(root, closestNeighbors, position, sqRange, cosThreshold, maxNbNeighbors);
0142     }
0143 
0144     return closestNeighbors;
0145 }
0146 
0147 void FaceDb::clearTreeDb() const
0148 {
0149     d->db->execSql(QLatin1String("DELETE FROM KDTree;"));
0150 }
0151 
0152 void FaceDb::updateRangeTreeDb(int nodeId, cv::Mat& minRange, cv::Mat& maxRange, const cv::Mat& position) const
0153 {
0154     float* const min = minRange.ptr<float>();
0155     float* const max = maxRange.ptr<float>();
0156     const float* pos = position.ptr<float>();
0157 
0158     for (int i = 0 ; i < position.cols ; ++i)
0159     {
0160         max[i] = std::max(max[i], pos[i]);
0161         min[i] = std::min(min[i], pos[i]);
0162     }
0163 
0164     QVariantList bindingValues;
0165 
0166     bindingValues << QByteArray::fromRawData((char*)max, (sizeof(float) * 128));
0167     bindingValues << QByteArray::fromRawData((char*)min, (sizeof(float) * 128));
0168     bindingValues << nodeId;
0169 
0170     DbEngineSqlQuery query = d->db->execQuery(QLatin1String("UPDATE KDTree SET max_range = ?, min_range = ? WHERE id = ?;"),
0171                                               bindingValues);
0172 }
0173 
0174 int FaceDb::findParentTreeDb(const cv::Mat& nodePos, bool& leftChild, int& parentSplitAxis) const
0175 {
0176     int parent           = 1;
0177     QVariant currentNode = parent;
0178 
0179     while (currentNode.isValid() && !currentNode.isNull())
0180     {
0181         parent = currentNode.toInt();
0182 
0183         QVariantList bindingValues;
0184         bindingValues << parent;
0185 
0186         DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT split_axis, position, max_range, min_range, `left`, `right` "
0187                                                                 "FROM KDTree WHERE id = ?"),
0188                                                   bindingValues);
0189 
0190         if (! query.next())
0191         {
0192             if (parent == 1)
0193             {
0194 /*
0195                 qCDebug(DIGIKAM_FACEDB_LOG) << "add root";
0196 */
0197                 return 0;
0198             }
0199 
0200             qCWarning(DIGIKAM_FACEDB_LOG) << "Error query parent =" << parent << query.lastError();
0201             return -1;
0202         }
0203 
0204 /*
0205         qCDebug(DIGIKAM_FACEDB_LOG) << "split axis" << query.value(0).toInt()
0206                                     << "left"       << query.value(4)
0207                                     << "right"      << query.value(5);
0208 */
0209 
0210         int split        = query.value(0).toInt();
0211         parentSplitAxis  = split;
0212 
0213         int embeddingId  = query.value(1).toInt();
0214         cv::Mat maxRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone();
0215         cv::Mat minRange = cv::Mat(1, 128, CV_32F, query.value(3).toByteArray().data()).clone();
0216         QVariant left    = query.value(4);
0217         QVariant right   = query.value(5);
0218 
0219         bindingValues.clear();
0220         bindingValues << embeddingId;
0221 
0222         query = d->db->execQuery(QLatin1String("SELECT embedding FROM FaceMatrices WHERE id = ?"),
0223                                  bindingValues);
0224 
0225         if (! query.next())
0226         {
0227             qCWarning(DIGIKAM_FACEDB_LOG) << "fail to find parent face embedding" << query.lastError();
0228             return -1;
0229         }
0230 
0231         cv::Mat position = cv::Mat(1, 128, CV_32F, query.value(0).toByteArray().data()).clone();
0232 
0233         if (nodePos.at<float>(0, split) >= position.at<float>(0, split))
0234         {
0235             currentNode = right;
0236             leftChild   = false;
0237         }
0238         else
0239         {
0240             currentNode = left;
0241             leftChild   = true;
0242         }
0243 
0244          updateRangeTreeDb(parent, minRange, maxRange, nodePos);
0245     }
0246 
0247     return parent;
0248 }
0249 
0250 double FaceDb::getClosestNeighborsTreeDb(const DataNode& subTree,
0251                                          QMap<double, QVector<int> >& neighborList,
0252                                          const cv::Mat& position,
0253                                          float sqRange,
0254                                          float cosThreshold,
0255                                          int maxNbNeighbors) const
0256 {
0257     if (subTree.isNull())
0258     {
0259         return sqRange;
0260     }
0261 
0262     // try to add current node to the list
0263 
0264     const float sqrdistanceToCurrentNode = KDNode::sqrDistance(position.ptr<float>(), subTree.position.ptr<float>(), 128);
0265 
0266     if ((sqrdistanceToCurrentNode < sqRange) &&
0267         (KDNode::cosDistance(position.ptr<float>(), subTree.position.ptr<float>(), 128) > cosThreshold))
0268     {
0269         neighborList[sqrdistanceToCurrentNode].append(subTree.label);
0270 
0271         // limit the size of the Map to maxNbNeighbors
0272 
0273         int size = 0;
0274 
0275         for (QMap<double, QVector<int> >::const_iterator iter  = neighborList.cbegin();
0276                                                          iter != neighborList.cend();
0277                                                          ++iter)
0278         {
0279             size += iter.value().size();
0280         }
0281 
0282         if (size > maxNbNeighbors)
0283         {
0284             // Eliminate the farthest neighbor
0285 
0286             QMap<double, QVector<int> >::iterator farthestNodes = std::prev(neighborList.end(), 1);
0287 
0288             if (farthestNodes.value().size() == 1)
0289             {
0290                 neighborList.erase(farthestNodes);
0291             }
0292             else
0293             {
0294                 farthestNodes.value().pop_back();
0295             }
0296 
0297             // update the searching range
0298 
0299             sqRange = neighborList.lastKey();
0300         }
0301     }
0302 
0303     // sub-trees Traversal
0304 
0305     double sqrDistanceLeftTree  = 0.0;
0306     double sqrDistanceRightTree = 0.0;
0307     DataNode leftNode;
0308     DataNode rightNode;
0309 
0310     if (subTree.left <= 0)
0311     {
0312         sqrDistanceLeftTree = DBL_MAX;
0313     }
0314     else
0315     {
0316         QVariantList bindingValues;
0317         bindingValues << subTree.left;
0318 
0319         DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` "
0320                                                                 "FROM KDTree WHERE id = ?"),
0321                                                   bindingValues);
0322 
0323         if (query.next())
0324         {
0325             // encapsulate data node
0326 
0327             leftNode.nodeID   = subTree.left;
0328             int embeddingID   = query.value(0).toInt();
0329             leftNode.maxRange = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0330             leftNode.minRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone();
0331             leftNode.left     = query.value(3).toInt();
0332             leftNode.right    = query.value(4).toInt();
0333 
0334             bindingValues.clear();
0335             bindingValues << embeddingID;
0336 
0337             query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"),
0338                                      bindingValues);
0339 
0340             if (query.next())
0341             {
0342                 leftNode.label        = query.value(0).toInt();
0343                 leftNode.position     = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0344 
0345                 const float* minRange = leftNode.minRange.ptr<float>();
0346                 const float* maxRange = leftNode.maxRange.ptr<float>();
0347                 const float* pos      = leftNode.position.ptr<float>();
0348 
0349                 for (int i = 0 ; i < 128 ; ++i)
0350                 {
0351                     sqrDistanceLeftTree += (pow(qMax((minRange[i] - pos[i]),      0.0f), 2) +
0352                                             pow(qMax((pos[i]      - maxRange[i]), 0.0f), 2));
0353                 }
0354             }
0355         }
0356     }
0357 
0358     if (subTree.right <= 0)
0359     {
0360         sqrDistanceRightTree = DBL_MAX;
0361     }
0362     else
0363     {
0364         QVariantList bindingValues;
0365         bindingValues << subTree.right;
0366 
0367         DbEngineSqlQuery query = d->db->execQuery(QLatin1String("SELECT position, max_range, min_range, `left`, `right` "
0368                                                                 "FROM KDTree WHERE id = ?"),
0369                                                   bindingValues);
0370 
0371         if (query.next())
0372         {
0373             // encapsulate data node
0374 
0375             rightNode.nodeID   = subTree.right;
0376             int embeddingID    = query.value(0).toInt();
0377             rightNode.maxRange = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0378             rightNode.minRange = cv::Mat(1, 128, CV_32F, query.value(2).toByteArray().data()).clone();
0379             rightNode.left     = query.value(3).toInt();
0380             rightNode.right    = query.value(4).toInt();
0381 
0382             bindingValues.clear();
0383             bindingValues << embeddingID;
0384 
0385             query = d->db->execQuery(QLatin1String("SELECT identity, embedding FROM FaceMatrices WHERE id = ?"),
0386                                      bindingValues);
0387 
0388             if (query.next())
0389             {
0390                 rightNode.label             = query.value(0).toInt();
0391                 rightNode.position          = cv::Mat(1, 128, CV_32F, query.value(1).toByteArray().data()).clone();
0392 
0393                 const float* const minRange = rightNode.minRange.ptr<float>();
0394                 const float* const maxRange = rightNode.maxRange.ptr<float>();
0395                 const float* const pos      = rightNode.position.ptr<float>();
0396 
0397                 for (int i = 0 ; i < 128 ; ++i)
0398                 {
0399                     sqrDistanceRightTree += (pow(qMax((minRange[i] - pos[i]),      0.0f), 2) +
0400                                              pow(qMax((pos[i]      - maxRange[i]), 0.0f), 2));
0401                 }
0402             }
0403         }
0404     }
0405 
0406     // traverse the closest area
0407 
0408     if (sqrDistanceLeftTree < sqrDistanceRightTree)
0409     {
0410         if (sqrDistanceLeftTree < sqRange)
0411         {
0412             // traverse left Tree
0413 
0414             sqRange = getClosestNeighborsTreeDb(leftNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0415 
0416             if (sqrDistanceRightTree < sqRange)
0417             {
0418                 // traverse right Tree
0419 
0420                 sqRange = getClosestNeighborsTreeDb(rightNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0421             }
0422         }
0423     }
0424     else
0425     {
0426         if (sqrDistanceRightTree < sqRange)
0427         {
0428             // traverse right Tree
0429 
0430             sqRange = getClosestNeighborsTreeDb(rightNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0431 
0432             if (sqrDistanceLeftTree < sqRange)
0433             {
0434                 // traverse left Tree
0435 
0436                 sqRange = getClosestNeighborsTreeDb(leftNode, neighborList, position, sqRange, cosThreshold, maxNbNeighbors);
0437             }
0438         }
0439     }
0440 
0441     return sqRange;
0442 }
0443 
0444 } // namespace Digikam