File indexing completed on 2025-01-19 03:57:42

0001 /* ============================================================
0002  *
0003  * This file is a part of digiKam project
0004  * https://www.digikam.org
0005  *
0006  * Date        : 2019-08-10
0007  * Description : CLI tool to test and verify clustering for Face Recognition
0008  *
0009  * SPDX-FileCopyrightText: 2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com>
0010  *
0011  * SPDX-License-Identifier: GPL-2.0-or-later
0012  *
0013  * ============================================================ */
0014 
// C++ includes

#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

// Qt includes

#include <QCoreApplication>
#include <QDir>
#include <QImage>
#include <QElapsedTimer>
#include <QCommandLineParser>
#include <QList>

// Local includes

#include "digikam_debug.h"
#include "dimg.h"
#include "facescansettings.h"
#include "facedetector.h"
#include "coredbaccess.h"
#include "dbengineparameters.h"
#include "facialrecognition_wrapper.h"
0037 
0038 using namespace Digikam;
0039 
0040 // --------------------------------------------------------------------------------------------------
0041 
/**
 * Compute the intersection of two sorted integer sequences.
 *
 * @param v1   first input; must be sorted in ascending order.
 * @param v2   second input; must be sorted in ascending order.
 * @param vout receives the common elements in ascending order. They are
 *             inserted in front of any elements vout already holds.
 */
void intersection(const std::vector<int>& v1,
                  const std::vector<int>& v2,
                  std::vector<int>& vout)
{
    // std::set_intersection requires both ranges to be sorted. An insert
    // iterator anchored at begin() writes the matches in order, ahead of any
    // pre-existing content of vout.

    auto sink = std::inserter(vout, vout.begin());

    std::set_intersection(v1.cbegin(), v1.cend(),
                          v2.cbegin(), v2.cend(),
                          sink);
}
0055 
0056 /**
0057  * Function to return the Jaccard distance of two vectors
0058  */
0059 double jaccard_distance(const std::vector<int>& v1,
0060                         const std::vector<int>& v2)
0061 {
0062     // Sizes of both the sets
0063 
0064     double size_v1       = v1.size();
0065     double size_v2       = v2.size();
0066 
0067     // Get the intersection set
0068 
0069     std::vector<int> intersect;
0070     intersection(v1, v2, intersect);
0071 
0072     // Size of the intersection set
0073 
0074     double size_in       = intersect.size();
0075 
0076     // Calculate the Jaccard index
0077     // using the formula
0078 
0079     double jaccard_index = size_in / (size_v1 + size_v2 - size_in);
0080 
0081     // Calculate the Jaccard distance
0082     // using the formula
0083 
0084     double jaccard_dist  = 1 - jaccard_index;
0085 
0086     // Return the Jaccard distance
0087 
0088     return jaccard_dist;
0089 }
0090 
0091 QStringList toPaths(char** argv, int startIndex, int argc)
0092 {
0093     QStringList files;
0094 
0095     for (int i = startIndex ; i < argc ; ++i)
0096     {
0097         files << QString::fromLatin1(argv[i]);
0098     }
0099 
0100     return files;
0101 }
0102 
0103 QList<QImage> toImages(const QStringList& paths)
0104 {
0105     QList<QImage> images;
0106 
0107     Q_FOREACH (const QString& path, paths)
0108     {
0109         images << QImage(path);
0110     }
0111 
0112     return images;
0113 }
0114 
0115 int prepareForTrain(QString datasetPath,
0116                     QStringList& images,
0117                     std::vector<int>& testClusteredIndices)
0118 {
0119     if (!datasetPath.endsWith(QLatin1String("/")))
0120     {
0121         datasetPath.append(QLatin1String("/"));
0122     }
0123 
0124     QDir testSet(datasetPath);
0125     QStringList subjects = testSet.entryList(QDir::Dirs | QDir::NoDotAndDotDot | QDir::NoSymLinks);
0126     int nbOfClusters     = subjects.size();
0127 
0128     qCDebug(DIGIKAM_TESTS_LOG) << "Number of clusters to be defined" << nbOfClusters;
0129 
0130     for (int i = 1 ; i <= nbOfClusters ; ++i)
0131     {
0132         QString subjectPath               = QString::fromLatin1("%1%2")
0133                                                      .arg(datasetPath)
0134                                                      .arg(subjects.takeFirst());
0135         QDir subjectDir(subjectPath);
0136 
0137         QStringList files                 = subjectDir.entryList(QDir::Files);
0138         unsigned int nbOfFacesPerClusters = files.size();
0139 
0140         for (unsigned j = 1 ; j <= nbOfFacesPerClusters ; ++j)
0141         {
0142             QString path = QString::fromLatin1("%1/%2").arg(subjectPath)
0143                                                        .arg(files.takeFirst());
0144 
0145             testClusteredIndices.push_back(i - 1);
0146             images << path;
0147         }
0148     }
0149 
0150     qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters (prepareForTrain) " << nbOfClusters;
0151 
0152     return nbOfClusters;
0153 }
0154 
0155 QList<QRectF> processFaceDetection(const QString& imagePath, FaceDetector detector)
0156 {
0157     QList<QRectF> detectedFaces = detector.detectFaces(imagePath);
0158 
0159     qCDebug(DIGIKAM_TESTS_LOG) << "(Input CV) Found " << detectedFaces.size() << " faces";
0160 
0161     return detectedFaces;
0162 }
0163 
0164 QList<QImage> retrieveFaces(const QList<QImage>& images, const QList<QRectF>& rects)
0165 {
0166     QList<QImage> faces;
0167     unsigned index = 0;
0168 
0169     Q_FOREACH (const QRectF& rect, rects)
0170     {
0171         DImg temp(images.at(index));
0172         faces << temp.copyQImage(rect);
0173         ++index;
0174     }
0175 
0176     return faces;
0177 }
0178 
0179 void createClustersFromClusterIndices(const std::vector<int>& clusteredIndices,
0180                                       QList<std::vector<int>>& clusters)
0181 {
0182     int nbOfClusters = 0;
0183 
0184     for (size_t i = 0 ; i < clusteredIndices.size() ; ++i)
0185     {
0186         int nb = clusteredIndices[i];
0187 
0188         if (nb > nbOfClusters)
0189         {
0190             nbOfClusters = nb;
0191         }
0192     }
0193 
0194     nbOfClusters++;
0195 
0196     for (int i = 0 ; i < nbOfClusters ; ++i)
0197     {
0198         clusters << std::vector<int>();
0199     }
0200 
0201     qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters " << clusters.size();
0202 
0203     for (int i = 0 ; i < (int)clusteredIndices.size() ; ++i)
0204     {
0205         clusters[clusteredIndices[i]].push_back(i);
0206     }
0207 }
0208 
0209 void verifyClusteringResults(const std::vector<int>& clusteredIndices,
0210                              const std::vector<int>& testClusteredIndices,
0211                              const QStringList& dataset,
0212                              QStringList& falsePositiveCases)
0213 {
0214     QList<std::vector<int>> clusters, testClusters;
0215     createClustersFromClusterIndices(clusteredIndices, clusters);
0216     createClustersFromClusterIndices(testClusteredIndices, testClusters);
0217 
0218     std::set<int> falsePositivePoints;
0219     int testClustersSize = testClusters.size();
0220     std::vector<float> visited(testClustersSize, 1.0);
0221     std::vector<std::set<int>> lastVisit(testClustersSize, std::set<int>{});
0222 
0223     for (int i = 0 ; i < testClustersSize ; ++i)
0224     {
0225         std::vector<int> refSet = testClusters.at(i);
0226         double minDist          = 1.0;
0227         int indice              = 0;
0228 
0229         for (int j = 0 ; j < clusters.size() ; ++j)
0230         {
0231             double dist = jaccard_distance(refSet, clusters.at(j));
0232 
0233             if (dist < minDist)
0234             {
0235                 indice  = j;
0236                 minDist = dist;
0237             }
0238         }
0239 
0240         qCDebug(DIGIKAM_TESTS_LOG) << "testCluster " << i << " with group " << indice;
0241 
0242         std::vector<int> similarSet = clusters.at(indice);
0243 
0244         if (minDist < visited[indice])
0245         {
0246             visited[indice]            = minDist;
0247             std::set<int> lastVisitSet = lastVisit[indice];
0248             std::set<int> newVisitSet;
0249             std::set_symmetric_difference(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
0250                                           std::inserter(newVisitSet, newVisitSet.begin()));
0251 
0252             for (int elm: lastVisitSet)
0253             {
0254                 falsePositivePoints.erase(elm);
0255             }
0256 
0257             lastVisit[indice] = newVisitSet;
0258             falsePositivePoints.insert(newVisitSet.begin(), newVisitSet.end());
0259         }
0260         else
0261         {
0262             std::set_intersection(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
0263                                   std::inserter(falsePositivePoints, falsePositivePoints.begin()));
0264         }
0265     }
0266 
0267     for (const auto& indx : falsePositivePoints)
0268     {
0269         falsePositiveCases << dataset[indx];
0270     }
0271 }
0272 
0273 // --------------------------------------------------------------------------------------------------
0274 
0275 int main(int argc, char* argv[])
0276 {
0277     QCoreApplication app(argc, argv);
0278     app.setApplicationName(QString::fromLatin1("digikam"));          // for DB init.
0279 
0280     // Options for commandline parser
0281 
0282     QCommandLineParser parser;
0283     parser.addOption(QCommandLineOption(QLatin1String("db"),
0284                      QLatin1String("Faces database"),
0285                      QLatin1String("path to db folder")));
0286     parser.addHelpOption();
0287     parser.process(app);
0288 
0289     // Parse arguments
0290 
0291     bool optionErrors = false;
0292 
0293     if      (parser.optionNames().empty())
0294     {
0295         qCWarning(DIGIKAM_TESTS_LOG) << "No options!!!";
0296         optionErrors = true;
0297     }
0298     else if (!parser.isSet(QLatin1String("db")))
0299     {
0300         qCWarning(DIGIKAM_TESTS_LOG) << "Missing database for test!!!";
0301         optionErrors = true;
0302     }
0303 
0304     if (optionErrors)
0305     {
0306         parser.showHelp();
0307         return 1;
0308     }
0309 
0310     QString facedb         = parser.value(QLatin1String("db"));
0311 
0312     // Init config for digiKam
0313 
0314     DbEngineParameters prm = DbEngineParameters::parametersFromConfig();
0315     CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication);
0316     FacialRecognitionWrapper recognizer;
0317 
0318     //db.setRecognizerThreshold(0.91F);       // This is sensitive for the performance of face clustering
0319 
0320     // Construct test set, data set
0321 
0322     QStringList dataset;
0323     std::vector<int> testClusteredIndices;
0324     int nbOfClusters = prepareForTrain(facedb, dataset, testClusteredIndices);
0325 
0326     // Init FaceDetector used for detecting faces and bounding box
0327     // before recognizing
0328 
0329     FaceDetector detector;
0330 
0331     // Evaluation metrics
0332 
0333     unsigned totalClustered    = 0;
0334     unsigned elapsedClustering = 0;
0335 
0336     QStringList undetectedFaces;
0337 
0338     QList<QImage> detectedFaces;
0339     QList<QRectF> bboxes;
0340     QList<QImage> rawImages    = toImages(dataset);
0341 
0342     Q_FOREACH (const QImage& image, rawImages)
0343     {
0344         QString imagePath                 = dataset.takeFirst();
0345         QList<QRectF> detectedBoundingBox = processFaceDetection(imagePath, detector);
0346 
0347         if (detectedBoundingBox.size())
0348         {
0349             detectedFaces << image;
0350             bboxes        << detectedBoundingBox.first();
0351             dataset       << imagePath;
0352 
0353             ++totalClustered;
0354         }
0355         else
0356         {
0357             undetectedFaces << imagePath;
0358         }
0359     }
0360 
0361     std::vector<int> clusteredIndices(dataset.size(), -1);
0362     QList<QImage> faces = retrieveFaces(detectedFaces, bboxes);
0363 
0364     QElapsedTimer timer;
0365 
0366     timer.start();
0367 /*
0368     TODO: port to new API
0369     db.clusterFaces(faces, clusteredIndices, dataset, nbOfClusters);
0370 */
0371     elapsedClustering  += timer.elapsed();
0372 
0373     // Verify clustering
0374 
0375     QStringList falsePositiveCases;
0376     verifyClusteringResults(clusteredIndices, testClusteredIndices, dataset, falsePositiveCases);
0377 
0378     // Display results
0379 
0380     unsigned nbUndetectedFaces = undetectedFaces.size();
0381     qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbUndetectedFaces << " / " << dataset.size() + nbUndetectedFaces
0382              << " (" << float(nbUndetectedFaces) / (dataset.size() + nbUndetectedFaces) * 100 << "%)"
0383              << " faces cannot be detected";
0384 
0385     Q_FOREACH (const QString& path, undetectedFaces)
0386     {
0387         qCDebug(DIGIKAM_TESTS_LOG) << path;
0388     }
0389 
0390     unsigned nbOfFalsePositiveCases = falsePositiveCases.size();
0391     qCDebug(DIGIKAM_TESTS_LOG) << "\nFalse positive cases";
0392     qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbOfFalsePositiveCases << " / " << dataset.size()
0393              << " (" << float(nbOfFalsePositiveCases*100) / dataset.size()<< "%)"
0394              << " faces were wrongly clustered";
0395 
0396     Q_FOREACH (const QString& imagePath, falsePositiveCases)
0397     {
0398         qCDebug(DIGIKAM_TESTS_LOG) << imagePath;
0399     }
0400 
0401     qCDebug(DIGIKAM_TESTS_LOG) << "\n Time for clustering " << elapsedClustering << " ms";
0402 
0403     return 0;
0404 }