File indexing completed on 2025-01-19 03:57:42
0001 /* ============================================================ 0002 * 0003 * This file is a part of digiKam project 0004 * https://www.digikam.org 0005 * 0006 * Date : 2019-08-10 0007 * Description : CLI tool to test and verify clustering for Face Recognition 0008 * 0009 * SPDX-FileCopyrightText: 2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com> 0010 * 0011 * SPDX-License-Identifier: GPL-2.0-or-later 0012 * 0013 * ============================================================ */ 0014 0015 // C++ includes 0016 0017 #include <set> 0018 0019 // Qt includes 0020 0021 #include <QCoreApplication> 0022 #include <QDir> 0023 #include <QImage> 0024 #include <QElapsedTimer> 0025 #include <QCommandLineParser> 0026 #include <QList> 0027 0028 // Local includes 0029 0030 #include "digikam_debug.h" 0031 #include "dimg.h" 0032 #include "facescansettings.h" 0033 #include "facedetector.h" 0034 #include "coredbaccess.h" 0035 #include "dbengineparameters.h" 0036 #include "facialrecognition_wrapper.h" 0037 0038 using namespace Digikam; 0039 0040 // -------------------------------------------------------------------------------------------------- 0041 0042 /** 0043 * Function to return the 0044 * intersection vector of v1 and v2 0045 */ 0046 void intersection(const std::vector<int>& v1, 0047 const std::vector<int>& v2, 0048 std::vector<int>& vout) 0049 { 0050 // Find the intersection of the two sets 0051 0052 std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), 0053 std::inserter(vout, vout.begin())); 0054 } 0055 0056 /** 0057 * Function to return the Jaccard distance of two vectors 0058 */ 0059 double jaccard_distance(const std::vector<int>& v1, 0060 const std::vector<int>& v2) 0061 { 0062 // Sizes of both the sets 0063 0064 double size_v1 = v1.size(); 0065 double size_v2 = v2.size(); 0066 0067 // Get the intersection set 0068 0069 std::vector<int> intersect; 0070 intersection(v1, v2, intersect); 0071 0072 // Size of the intersection set 0073 0074 double size_in = intersect.size(); 0075 0076 // Calculate the Jaccard index 0077 // using the formula 0078 0079 double jaccard_index = size_in / (size_v1 + size_v2 - size_in); 0080 0081 // Calculate the Jaccard distance 0082 // using the formula 0083 0084 double jaccard_dist = 1 - jaccard_index; 0085 0086 // Return the Jaccard distance 0087 0088 return jaccard_dist; 0089 } 0090 0091 QStringList toPaths(char** argv, int startIndex, int argc) 0092 { 0093 QStringList files; 0094 0095 for (int i = startIndex ; i < argc ; ++i) 0096 { 0097 files << QString::fromLatin1(argv[i]); 0098 } 0099 0100 return files; 0101 } 0102 0103 QList<QImage> toImages(const QStringList& paths) 0104 { 0105 QList<QImage> images; 0106 0107 Q_FOREACH (const QString& path, paths) 0108 { 0109 images << QImage(path); 0110 } 0111 0112 return images; 0113 } 0114 0115 int prepareForTrain(QString datasetPath, 0116 QStringList& images, 0117 std::vector<int>& testClusteredIndices) 0118 { 0119 if (!datasetPath.endsWith(QLatin1String("/"))) 0120 { 0121 datasetPath.append(QLatin1String("/")); 0122 } 0123 0124 QDir testSet(datasetPath); 0125 QStringList subjects = testSet.entryList(QDir::Dirs | QDir::NoDotAndDotDot | QDir::NoSymLinks); 0126 int nbOfClusters = subjects.size(); 0127 0128 qCDebug(DIGIKAM_TESTS_LOG) << "Number of clusters to be defined" << nbOfClusters; 0129 0130 for (int i = 1 ; i <= nbOfClusters ; ++i) 0131 { 0132 QString subjectPath = QString::fromLatin1("%1%2") 0133 .arg(datasetPath) 0134 .arg(subjects.takeFirst()); 0135 QDir subjectDir(subjectPath); 0136 0137 QStringList files = subjectDir.entryList(QDir::Files); 0138 unsigned int nbOfFacesPerClusters = files.size(); 0139 0140 for (unsigned j = 1 ; j <= nbOfFacesPerClusters ; ++j) 0141 { 0142 QString path = QString::fromLatin1("%1/%2").arg(subjectPath) 0143 .arg(files.takeFirst()); 0144 0145 testClusteredIndices.push_back(i - 1); 0146 images << path; 0147 } 0148 } 0149 0150 qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters (prepareForTrain) " << nbOfClusters; 0151 0152 return nbOfClusters; 0153 } 0154 0155 QList<QRectF> processFaceDetection(const QString& imagePath, FaceDetector detector) 0156 { 0157 QList<QRectF> detectedFaces = detector.detectFaces(imagePath); 0158 0159 qCDebug(DIGIKAM_TESTS_LOG) << "(Input CV) Found " << detectedFaces.size() << " faces"; 0160 0161 return detectedFaces; 0162 } 0163 0164 QList<QImage> retrieveFaces(const QList<QImage>& images, const QList<QRectF>& rects) 0165 { 0166 QList<QImage> faces; 0167 unsigned index = 0; 0168 0169 Q_FOREACH (const QRectF& rect, rects) 0170 { 0171 DImg temp(images.at(index)); 0172 faces << temp.copyQImage(rect); 0173 ++index; 0174 } 0175 0176 return faces; 0177 } 0178 0179 void createClustersFromClusterIndices(const std::vector<int>& clusteredIndices, 0180 QList<std::vector<int>>& clusters) 0181 { 0182 int nbOfClusters = 0; 0183 0184 for (size_t i = 0 ; i < clusteredIndices.size() ; ++i) 0185 { 0186 int nb = clusteredIndices[i]; 0187 0188 if (nb > nbOfClusters) 0189 { 0190 nbOfClusters = nb; 0191 } 0192 } 0193 0194 nbOfClusters++; 0195 0196 for (int i = 0 ; i < nbOfClusters ; ++i) 0197 { 0198 clusters << std::vector<int>(); 0199 } 0200 0201 qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters " << clusters.size(); 0202 0203 for (int i = 0 ; i < (int)clusteredIndices.size() ; ++i) 0204 { 0205 clusters[clusteredIndices[i]].push_back(i); 0206 } 0207 } 0208 0209 void verifyClusteringResults(const std::vector<int>& clusteredIndices, 0210 const std::vector<int>& testClusteredIndices, 0211 const QStringList& dataset, 0212 QStringList& falsePositiveCases) 0213 { 0214 QList<std::vector<int>> clusters, testClusters; 0215 createClustersFromClusterIndices(clusteredIndices, clusters); 0216 createClustersFromClusterIndices(testClusteredIndices, testClusters); 0217 0218 std::set<int> falsePositivePoints; 0219 int testClustersSize = testClusters.size(); 0220 std::vector<float> visited(testClustersSize, 1.0); 0221 std::vector<std::set<int>> lastVisit(testClustersSize, std::set<int>{}); 0222 0223 for (int i = 0 ; i < testClustersSize ; ++i) 0224 { 0225 std::vector<int> refSet = testClusters.at(i); 0226 double minDist = 1.0; 0227 int indice = 0; 0228 0229 for (int j = 0 ; j < clusters.size() ; ++j) 0230 { 0231 double dist = jaccard_distance(refSet, clusters.at(j)); 0232 0233 if (dist < minDist) 0234 { 0235 indice = j; 0236 minDist = dist; 0237 } 0238 } 0239 0240 qCDebug(DIGIKAM_TESTS_LOG) << "testCluster " << i << " with group " << indice; 0241 0242 std::vector<int> similarSet = clusters.at(indice); 0243 0244 if (minDist < visited[indice]) 0245 { 0246 visited[indice] = minDist; 0247 std::set<int> lastVisitSet = lastVisit[indice]; 0248 std::set<int> newVisitSet; 0249 std::set_symmetric_difference(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(), 0250 std::inserter(newVisitSet, newVisitSet.begin())); 0251 0252 for (int elm: lastVisitSet) 0253 { 0254 falsePositivePoints.erase(elm); 0255 } 0256 0257 lastVisit[indice] = newVisitSet; 0258 falsePositivePoints.insert(newVisitSet.begin(), newVisitSet.end()); 0259 } 0260 else 0261 { 0262 std::set_intersection(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(), 0263 std::inserter(falsePositivePoints, falsePositivePoints.begin())); 0264 } 0265 } 0266 0267 for (const auto& indx : falsePositivePoints) 0268 { 0269 falsePositiveCases << dataset[indx]; 0270 } 0271 } 0272 0273 // -------------------------------------------------------------------------------------------------- 0274 0275 int main(int argc, char* argv[]) 0276 { 0277 QCoreApplication app(argc, argv); 0278 app.setApplicationName(QString::fromLatin1("digikam")); // for DB init. 0279 0280 // Options for commandline parser 0281 0282 QCommandLineParser parser; 0283 parser.addOption(QCommandLineOption(QLatin1String("db"), 0284 QLatin1String("Faces database"), 0285 QLatin1String("path to db folder"))); 0286 parser.addHelpOption(); 0287 parser.process(app); 0288 0289 // Parse arguments 0290 0291 bool optionErrors = false; 0292 0293 if (parser.optionNames().empty()) 0294 { 0295 qCWarning(DIGIKAM_TESTS_LOG) << "No options!!!"; 0296 optionErrors = true; 0297 } 0298 else if (!parser.isSet(QLatin1String("db"))) 0299 { 0300 qCWarning(DIGIKAM_TESTS_LOG) << "Missing database for test!!!"; 0301 optionErrors = true; 0302 } 0303 0304 if (optionErrors) 0305 { 0306 parser.showHelp(); 0307 return 1; 0308 } 0309 0310 QString facedb = parser.value(QLatin1String("db")); 0311 0312 // Init config for digiKam 0313 0314 DbEngineParameters prm = DbEngineParameters::parametersFromConfig(); 0315 CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication); 0316 FacialRecognitionWrapper recognizer; 0317 0318 //db.setRecognizerThreshold(0.91F); // This is sensitive for the performance of face clustering 0319 0320 // Construct test set, data set 0321 0322 QStringList dataset; 0323 std::vector<int> testClusteredIndices; 0324 int nbOfClusters = prepareForTrain(facedb, dataset, testClusteredIndices); 0325 0326 // Init FaceDetector used for detecting faces and bounding box 0327 // before recognizing 0328 0329 FaceDetector detector; 0330 0331 // Evaluation metrics 0332 0333 unsigned totalClustered = 0; 0334 unsigned elapsedClustering = 0; 0335 0336 QStringList undetectedFaces; 0337 0338 QList<QImage> detectedFaces; 0339 QList<QRectF> bboxes; 0340 QList<QImage> rawImages = toImages(dataset); 0341 0342 Q_FOREACH (const QImage& image, rawImages) 0343 { 0344 QString imagePath = dataset.takeFirst(); 0345 QList<QRectF> detectedBoundingBox = processFaceDetection(imagePath, detector); 0346 0347 if (detectedBoundingBox.size()) 0348 { 0349 detectedFaces << image; 0350 bboxes << detectedBoundingBox.first(); 0351 dataset << imagePath; 0352 0353 ++totalClustered; 0354 } 0355 else 0356 { 0357 undetectedFaces << imagePath; 0358 } 0359 } 0360 0361 std::vector<int> clusteredIndices(dataset.size(), -1); 0362 QList<QImage> faces = retrieveFaces(detectedFaces, bboxes); 0363 0364 QElapsedTimer timer; 0365 0366 timer.start(); 0367 /* 0368 TODO: port to new API 0369 db.clusterFaces(faces, clusteredIndices, dataset, nbOfClusters); 0370 */ 0371 elapsedClustering += timer.elapsed(); 0372 0373 // Verify clustering 0374 0375 QStringList falsePositiveCases; 0376 verifyClusteringResults(clusteredIndices, testClusteredIndices, dataset, falsePositiveCases); 0377 0378 // Display results 0379 0380 unsigned nbUndetectedFaces = undetectedFaces.size(); 0381 qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbUndetectedFaces << " / " << dataset.size() + nbUndetectedFaces 0382 << " (" << float(nbUndetectedFaces) / (dataset.size() + nbUndetectedFaces) * 100 << "%)" 0383 << " faces cannot be detected"; 0384 0385 Q_FOREACH (const QString& path, undetectedFaces) 0386 { 0387 qCDebug(DIGIKAM_TESTS_LOG) << path; 0388 } 0389 0390 unsigned nbOfFalsePositiveCases = falsePositiveCases.size(); 0391 qCDebug(DIGIKAM_TESTS_LOG) << "\nFalse positive cases"; 0392 qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbOfFalsePositiveCases << " / " << dataset.size() 0393 << " (" << float(nbOfFalsePositiveCases*100) / dataset.size()<< "%)" 0394 << " faces were wrongly clustered"; 0395 0396 Q_FOREACH (const QString& imagePath, falsePositiveCases) 0397 { 0398 qCDebug(DIGIKAM_TESTS_LOG) << imagePath; 0399 } 0400 0401 qCDebug(DIGIKAM_TESTS_LOG) << "\n Time for clustering " << elapsedClustering << " ms"; 0402 0403 return 0; 0404 }