File indexing completed on 2025-03-09 03:56:55
0001 /* ============================================================ 0002 * 0003 * This file is a part of digiKam project 0004 * https://www.digikam.org 0005 * 0006 * Date : 2020-05-20 0007 * Description : Testing tool for dnn face recognition of face engines 0008 * 0009 * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com> 0010 * 0011 * SPDX-License-Identifier: GPL-2.0-or-later 0012 * 0013 * ============================================================ */ 0014 0015 // Qt includes 0016 0017 #include <QCoreApplication> 0018 #include <QCommandLineParser> 0019 #include <QDir> 0020 #include <QImage> 0021 #include <QElapsedTimer> 0022 #include <QHash> 0023 #include <QJsonObject> 0024 #include <QJsonDocument> 0025 #include <QRandomGenerator> 0026 0027 // Local includes 0028 0029 #include "digikam_debug.h" 0030 #include "opencvdnnfacedetector.h" 0031 #include "facedetector.h" 0032 #include "dnnfaceextractor.h" 0033 #include "opencvdnnfacerecognizer.h" 0034 #include "facialrecognition_wrapper.h" 0035 #include "dbengineparameters.h" 0036 #include "coredbaccess.h" 0037 0038 using namespace Digikam; 0039 0040 /** 0041 * @brief The Benchmark class: a benchmark especially made for testing the performance of face recognition module 0042 * on faces images. The dataset should contain only face data. It's not built for other purposes. 0043 */ 0044 class Benchmark : public QObject 0045 { 0046 Q_OBJECT 0047 0048 public: 0049 0050 Benchmark(QObject* const parent = nullptr); 0051 ~Benchmark() override; 0052 0053 public: 0054 0055 void verifyTestSet(); 0056 void splitData(const QDir& dataDir, float splitRatio); 0057 0058 public: 0059 0060 QCommandLineParser* m_parser; 0061 float m_error; 0062 int m_trainSize; 0063 int m_testSize; 0064 0065 private: 0066 0067 QImage* detect(const QImage& image) const; 0068 0069 QList<QImage*> detect(const QList<QImage*>& images) const; 0070 0071 bool preprocess(QImage* faceImg, cv::Mat& face) const; 0072 0073 public Q_SLOTS: 0074 0075 void fetchData(); 0076 void registerTrainingSet(); 0077 void saveData(); 0078 void testWriteDb(); 0079 void verifyKNearestDb(); 0080 0081 private: 0082 0083 QHash<QString, QList<QImage*> > m_trainSet; 0084 QHash<QString, QList<QImage*> > m_testSet; 0085 0086 FaceDetector* m_detector; 0087 FacialRecognitionWrapper* m_recognizer; 0088 }; 0089 0090 // -------------------------------------------------------- 0091 0092 // NOTE: dnn face detector can be parallelized but it takes longer than a single thread 0093 0094 class ParallelDetector: public cv::ParallelLoopBody 0095 { 0096 public: 0097 0098 ParallelDetector(FaceDetector* const detector, 0099 const QList<QImage*>& images, 0100 QVector<QList<QRectF> >& rects) 0101 : m_detector(detector), 0102 m_images (images), 0103 m_rects (rects) 0104 { 0105 m_rects.resize(images.size()); 0106 } 0107 0108 void operator()(const cv::Range& range) const override 0109 { 0110 for (int i = range.start ; i < range.end ; ++i) 0111 { 0112 if (m_images[i]->isNull()) 0113 { 0114 m_rects[i] = QList<QRectF>(); 0115 } 0116 else 0117 { 0118 m_rects[i] = m_detector->detectFaces(*m_images[i]); 0119 } 0120 } 0121 } 0122 0123 private: 0124 0125 FaceDetector* m_detector; 0126 const QList<QImage*>& m_images; 0127 QVector<QList<QRectF> >& m_rects; 0128 0129 private: 0130 0131 Q_DISABLE_COPY(ParallelDetector) 0132 }; 0133 0134 // -------------------------------------------------------- 0135 0136 Benchmark::Benchmark(QObject* const parent) 0137 : QObject (parent), 0138 m_parser (nullptr), 0139 m_error (-1), 0140 m_trainSize(0), 0141 m_testSize (0) 0142 { 0143 DbEngineParameters prm = DbEngineParameters::parametersFromConfig(); 0144 CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication); 0145 0146 m_detector = new FaceDetector(); 0147 m_recognizer = new FacialRecognitionWrapper(); 0148 0149 m_recognizer->clearAllTraining(QLatin1String("train face classifier")); 0150 m_recognizer->deleteIdentities(m_recognizer->allIdentities()); 0151 } 0152 0153 Benchmark::~Benchmark() 0154 { 0155 delete m_detector; 0156 delete m_recognizer; 0157 delete m_parser; 0158 } 0159 0160 void Benchmark::registerTrainingSet() 0161 { 0162 m_trainSize = 0; 0163 0164 QElapsedTimer timer; 0165 timer.start(); 0166 /* 0167 QMap<QString, QString> attributes; 0168 attributes[QLatin1String("fullName")] = m_trainSet.begin().key(); 0169 0170 Identity newIdentity = m_recognizer->addIdentity(attributes); 0171 0172 qCDebug(DIGIKAM_TESTS_LOG) << "add new identity to database" << newIdentity.id(); 0173 0174 m_recognizer->train(newIdentity, m_trainSet.begin().value(), QLatin1String("train face classifier")); 0175 0176 m_trainSize += m_trainSet.begin().value().size(); 0177 */ 0178 0179 for (QHash<QString, QList<QImage*> >::iterator iter = m_trainSet.begin(); 0180 iter != m_trainSet.end(); 0181 ++iter) 0182 { 0183 QMultiMap<QString, QString> attributes; 0184 attributes.insert(QLatin1String("fullName"), iter.key()); 0185 0186 Identity newIdentity = m_recognizer->addIdentity(attributes); 0187 0188 qCDebug(DIGIKAM_TESTS_LOG) << "add new identity to database" << newIdentity.id(); 0189 0190 m_recognizer->train(newIdentity, iter.value(), QLatin1String("train face classifier")); 0191 0192 m_trainSize += iter.value().size(); 0193 } 0194 0195 unsigned int elapsedDetection = timer.elapsed(); 0196 0197 qCDebug(DIGIKAM_TESTS_LOG) << "Registered << :" << m_trainSize 0198 << "faces in training set"; 0199 0200 if (m_trainSize != 0) 0201 { 0202 qCDebug(DIGIKAM_TESTS_LOG) << "with average" << float(elapsedDetection) / m_trainSize << "ms / face"; 0203 } 0204 } 0205 0206 void Benchmark::verifyTestSet() 0207 { 0208 int nbNotRecognize = 0; 0209 int nbWrongLabel = 0; 0210 m_testSize = 0; 0211 0212 QElapsedTimer timer; 0213 timer.start(); 0214 0215 for (QHash<QString, QList<QImage*> >::iterator iter = m_testSet.begin(); 0216 iter != m_testSet.end(); 0217 ++iter) 0218 { 0219 QList<Identity> predictions = m_recognizer->recognizeFaces(iter.value()); 0220 0221 for (int i = 0 ; i < predictions.size() ; ++i) 0222 { 0223 /* 0224 Identity prediction = m_recognizer->identity(ids[i]); 0225 */ 0226 if (predictions[i].isNull()) 0227 { 0228 if (m_trainSet.contains(iter.key())) 0229 { 0230 // cannot recognize when label is already register 0231 0232 ++nbNotRecognize; 0233 } 0234 } 0235 else if (predictions[i].attribute(QLatin1String("fullName")) != iter.key()) 0236 { 0237 // wrong label 0238 0239 ++nbWrongLabel; 0240 } 0241 } 0242 0243 m_testSize += iter.value().size(); 0244 } 0245 0246 unsigned int elapsedDetection = timer.elapsed(); 0247 0248 if (m_testSize == 0) 0249 { 0250 qCWarning(DIGIKAM_TESTS_LOG) << "test set is empty"; 0251 0252 return; 0253 } 0254 0255 m_error = float(nbNotRecognize + nbWrongLabel)/m_testSize; 0256 0257 qCDebug(DIGIKAM_TESTS_LOG) << "nb Not Recognized :" << nbNotRecognize; 0258 qCDebug(DIGIKAM_TESTS_LOG) << "nb Wrong Label :" << nbWrongLabel; 0259 0260 qCDebug(DIGIKAM_TESTS_LOG) << "Accuracy :" << (1 - m_error) * 100 << "%" 0261 << "on total" << m_trainSize << "training faces, and" 0262 << m_testSize << "test faces"; 0263 if (m_testSize != 0) 0264 { 0265 qCDebug(DIGIKAM_TESTS_LOG) << "(" << float(elapsedDetection) / m_testSize << "ms/face )"; 0266 } 0267 } 0268 0269 QImage* Benchmark::detect(const QImage& faceImg) const 0270 { 0271 if (faceImg.isNull()) 0272 { 0273 return nullptr; 0274 } 0275 0276 QList<QRectF> faces = m_detector->detectFaces(faceImg); 0277 0278 if (faces.isEmpty()) 0279 { 0280 return nullptr; 0281 } 0282 0283 QRect rect = FaceDetector::toAbsoluteRect(faces[0], faceImg.size()); 0284 0285 QImage* const croppedFace = new QImage(); 0286 *croppedFace = faceImg.copy(rect); 0287 0288 return croppedFace; 0289 } 0290 0291 QList<QImage*> Benchmark::detect(const QList<QImage*>& images) const 0292 { 0293 QVector<QList<QRectF> > faces; 0294 0295 cv::parallel_for_(cv::Range(0, images.size()), ParallelDetector(m_detector, images, faces)); 0296 0297 QList<QImage*> croppedFaces; 0298 0299 for (int i = 0 ; i < faces.size() ; ++i) 0300 { 0301 if (faces[i].isEmpty()) 0302 { 0303 croppedFaces << nullptr; 0304 } 0305 else 0306 { 0307 QRect rect = FaceDetector::toAbsoluteRect(faces[i][0], images[i]->size()); 0308 QImage* const croppedFace = new QImage(); 0309 *croppedFace = images[i]->copy(rect); 0310 0311 croppedFaces << croppedFace; 0312 } 0313 } 0314 0315 return croppedFaces; 0316 } 0317 0318 bool Benchmark::preprocess(QImage* faceImg, cv::Mat& face) const 0319 { 0320 QList<QRectF> faces = m_detector->detectFaces(*faceImg); 0321 0322 if (faces.isEmpty()) 0323 { 0324 return false; 0325 } 0326 0327 QRect rect = FaceDetector::toAbsoluteRect(faces[0], faceImg->size()); 0328 QImage croppedFace = faceImg->copy(rect); 0329 cv::Mat cvImageWrapper; 0330 0331 switch (croppedFace.format()) 0332 { 0333 case QImage::Format_RGB32: 0334 case QImage::Format_ARGB32: 0335 case QImage::Format_ARGB32_Premultiplied: 0336 0337 // I think we can ignore premultiplication when converting to grayscale 0338 0339 cvImageWrapper = cv::Mat(croppedFace.height(), croppedFace.width(), CV_8UC4, croppedFace.scanLine(0), croppedFace.bytesPerLine()); 0340 cv::cvtColor(cvImageWrapper, face, CV_RGBA2RGB); 0341 0342 break; 0343 0344 default: 0345 0346 croppedFace = croppedFace.convertToFormat(QImage::Format_RGB888); 0347 face = cv::Mat(croppedFace.height(), croppedFace.width(), CV_8UC3, croppedFace.scanLine(0), croppedFace.bytesPerLine()); 0348 /* 0349 cvtColor(cvImageWrapper, cvImage, CV_RGB2GRAY); 0350 */ 0351 break; 0352 } 0353 0354 return true; 0355 } 0356 0357 void Benchmark::splitData(const QDir& dataDir, float splitRatio) 0358 { 0359 int nbData = 0; 0360 0361 // Each subdirectory in data directory should match with a label 0362 0363 QFileInfoList subDirs = dataDir.entryInfoList(QDir::NoDotAndDotDot | QDir::Dirs, QDir::Name); 0364 0365 QElapsedTimer timer; 0366 timer.start(); 0367 0368 for (int i = 0 ; i < subDirs.size() ; ++i) 0369 { 0370 QDir subDir(subDirs[i].absoluteFilePath()); 0371 0372 QString label = subDirs[i].fileName(); 0373 0374 QFileInfoList filesInfo = subDir.entryInfoList(QDir::Files | QDir::Readable); 0375 0376 // suffle dataset 0377 0378 QList<QFileInfo>::iterator it = filesInfo.begin(); 0379 QList<QFileInfo>::iterator it1; 0380 0381 for (int j = 0 ; j < filesInfo.size() ; ++j) 0382 { 0383 int inc = QRandomGenerator::global()->bounded(filesInfo.size()); 0384 0385 it1 = filesInfo.begin(); 0386 it1 += inc; 0387 0388 std::swap(*(it++), *(it1)); 0389 } 0390 0391 //QString faceDir = QLatin1String("./cropped_face/"); 0392 /* 0393 QList<QImage*> images; 0394 0395 for (int i = 0 ; i < filesInfo.size() ; ++i) 0396 { 0397 images << new QImage(filesInfo[i].absoluteFilePath()); 0398 } 0399 0400 QList<QImage*> croppedFaces = detect(images); 0401 0402 // split train/test 0403 0404 for (int i = 0 ; i < croppedFaces.size() ; ++i) 0405 { 0406 if (croppedFaces[i]) 0407 { 0408 croppedFaces[i]->save(faceDir + label + QLatin1String("_") + QString::number(i) + QLatin1String(".png"), "PNG"); 0409 } 0410 0411 if (i < (filesInfo.size() * splitRatio)) 0412 { 0413 if (croppedFaces[i] && !croppedFaces[i]->isNull()) 0414 { 0415 m_trainSet[label].append(croppedFaces[i]); 0416 ++nbData; 0417 } 0418 } 0419 else 0420 { 0421 if (croppedFaces[i] && !croppedFaces[i]->isNull()) 0422 { 0423 m_testSet[label].append(croppedFaces[i]); 0424 ++nbData; 0425 } 0426 } 0427 } 0428 */ 0429 for (int j = 0 ; j < filesInfo.size() ; ++j) 0430 { 0431 QImage img(filesInfo[j].absoluteFilePath()); 0432 0433 QImage* const croppedFace = detect(img); 0434 0435 // Save cropped face for detection testing 0436 0437 if (croppedFace) 0438 { 0439 /* 0440 croppedFace->save(faceDir + label + QLatin1String("_") + QString::number(i) + QLatin1String(".png"), "PNG"); 0441 */ 0442 } 0443 0444 if (j < (filesInfo.size() * splitRatio)) 0445 { 0446 if (croppedFace && !croppedFace->isNull()) 0447 { 0448 m_trainSet[label].append(croppedFace); 0449 ++nbData; 0450 } 0451 } 0452 else 0453 { 0454 if (croppedFace && !croppedFace->isNull()) 0455 { 0456 m_testSet[label].append(croppedFace); 0457 ++nbData; 0458 } 0459 } 0460 } 0461 } 0462 0463 unsigned int elapsedDetection = timer.elapsed(); 0464 0465 qCDebug(DIGIKAM_TESTS_LOG) << "Fetched dataset with" << nbData << "samples"; 0466 0467 if (nbData != 0) 0468 { 0469 qCDebug(DIGIKAM_TESTS_LOG) << "with average" << float(elapsedDetection) / nbData << "ms/image."; 0470 } 0471 } 0472 0473 void Benchmark::fetchData() 0474 { 0475 if (! m_parser->isSet(QLatin1String("dataset"))) 0476 { 0477 qWarning("Data set is not set !!!"); 0478 0479 return; 0480 } 0481 0482 QDir dataset(m_parser->value(QLatin1String("dataset"))); 0483 0484 float splitRatio = 0.8; 0485 0486 if (m_parser->isSet(QLatin1String("split"))) 0487 { 0488 splitRatio = m_parser->value(QLatin1String("split")).toFloat(); 0489 } 0490 0491 splitData(dataset, splitRatio); 0492 } 0493 0494 void Benchmark::saveData() 0495 { 0496 QDir dataDir(m_parser->value(QLatin1String("dataset"))); 0497 QFileInfoList subDirs = dataDir.entryInfoList(QDir::NoDotAndDotDot | QDir::Dirs, QDir::Name); 0498 0499 QElapsedTimer timer; 0500 timer.start(); 0501 0502 QJsonArray faceEmbeddingArray; 0503 0504 for (int i = 0 ; i < subDirs.size() ; ++i) 0505 { 0506 QDir subDir(subDirs[i].absoluteFilePath()); 0507 QFileInfoList filesInfo = subDir.entryInfoList(QDir::Files | QDir::Readable); 0508 0509 for (int j = 0 ; j < filesInfo.size() ; ++j) 0510 { 0511 QImage* const img = new QImage(filesInfo[j].absoluteFilePath()); 0512 0513 if (! img->isNull()) 0514 { 0515 cv::Mat face; 0516 /* 0517 if (preprocess(img, face)) 0518 { 0519 0520 Identity newIdentity = m_recognizer->newIdentity(face); 0521 0522 QJsonObject identityJson; 0523 identityJson[QLatin1String("id")] = i; 0524 identityJson[QLatin1String("faceembedding")] = QJsonDocument::fromJson(newIdentity.attribute(QLatin1String("faceEmbedding")).toLatin1()).array(); 0525 0526 faceEmbeddingArray.append(identityJson); 0527 } 0528 */ 0529 } 0530 } 0531 } 0532 0533 qCDebug(DIGIKAM_TESTS_LOG) << "Save face embedding in" << timer.elapsed() << "ms/face"; 0534 0535 QFile dataFile(dataDir.dirName() + QLatin1String(".json")); 0536 0537 if (!dataFile.open(QIODevice::WriteOnly)) 0538 { 0539 qWarning("Couldn't open save file."); 0540 return; 0541 } 0542 0543 QJsonDocument saveDoc(faceEmbeddingArray); 0544 dataFile.write(saveDoc.toJson()); 0545 0546 dataFile.close(); 0547 } 0548 0549 void Benchmark::testWriteDb() 0550 { 0551 QDir dataDir(m_parser->value(QLatin1String("dataset"))); 0552 0553 QFile dataFile(dataDir.dirName() + QLatin1String(".json")); 0554 0555 if (!dataFile.open(QIODevice::ReadOnly)) 0556 { 0557 qWarning("Couldn't open data file."); 0558 return; 0559 } 0560 0561 QByteArray saveData = dataFile.readAll(); 0562 dataFile.close(); 0563 0564 QJsonDocument loadDoc(QJsonDocument::fromJson(saveData)); 0565 0566 QJsonArray data = loadDoc.array(); 0567 0568 QElapsedTimer timer; 0569 timer.start(); 0570 0571 for (int i = 0 ; i < data.size() ; ++i) 0572 { 0573 QJsonObject object = data[i].toObject(); 0574 std::vector<float> faceEmbedding = DNNFaceExtractor::decodeVector(object[QLatin1String("faceembedding")].toArray()); 0575 /* 0576 m_recognizer->insertData(DNNFaceExtractor::vectortomat(faceEmbedding), object[QLatin1String("id")].toInt()); 0577 */ 0578 } 0579 0580 qCDebug(DIGIKAM_TESTS_LOG) << "write face embedding to spatial database with average" << timer.elapsed() /data.size() << "ms/faceEmbedding"; 0581 } 0582 0583 void Benchmark::verifyKNearestDb() 0584 { 0585 QDir dataDir(m_parser->value(QLatin1String("dataset"))); 0586 0587 QFile dataFile(dataDir.dirName() + QLatin1String(".json")); 0588 0589 if (!dataFile.open(QIODevice::ReadOnly)) 0590 { 0591 qWarning("Couldn't open data file."); 0592 return; 0593 } 0594 0595 QByteArray saveData = dataFile.readAll(); 0596 dataFile.close(); 0597 0598 QJsonDocument loadDoc(QJsonDocument::fromJson(saveData)); 0599 0600 QJsonArray data = loadDoc.array(); 0601 int nbCorrect = 0; 0602 0603 QElapsedTimer timer; 0604 timer.start(); 0605 0606 for (int i = 0 ; i < data.size() ; ++i) 0607 { 0608 QJsonObject object = data[i].toObject(); 0609 std::vector<float> faceEmbedding = DNNFaceExtractor::decodeVector(object[QLatin1String("faceembedding")].toArray()); 0610 int label = object[QLatin1String("id")].toInt(); 0611 0612 QMap<double, QVector<int> > closestNeighbors 0613 /* 0614 = m_recognizer->getClosestNodes(DNNFaceExtractor::vectortomat(faceEmbedding), 1.0, 5) 0615 */ 0616 ; 0617 0618 QMap<int, QVector<double> > votingGroups; 0619 0620 for (QMap<double, QVector<int> >::const_iterator iter = closestNeighbors.cbegin(); 0621 iter != closestNeighbors.cend(); 0622 ++iter) 0623 { 0624 for (int j = 0 ; j < iter.value().size() ; ++j) 0625 { 0626 votingGroups[iter.value()[j]].append(iter.key()); 0627 } 0628 } 0629 0630 double maxScore = 0; 0631 int prediction = -1; 0632 0633 for (QMap<int, QVector<double> >::const_iterator group = votingGroups.cbegin(); 0634 group != votingGroups.cend(); 0635 ++group) 0636 { 0637 double score = 0; 0638 0639 for (int j = 0 ; j < group.value().size() ; ++j) 0640 { 0641 score += (1 - group.value()[j]); 0642 } 0643 0644 if (score > maxScore) 0645 { 0646 maxScore = score; 0647 prediction = group.key(); 0648 } 0649 } 0650 0651 if (label == prediction) 0652 { 0653 ++nbCorrect; 0654 } 0655 } 0656 0657 if (data.size() != 0) 0658 { 0659 qCDebug(DIGIKAM_TESTS_LOG) << "Accuracy" << (float(nbCorrect) / data.size())*100 0660 << "with average" << timer.elapsed() / data.size() 0661 << "ms/faceEmbedding"; 0662 } 0663 } 0664 0665 // -------------------------------------------------------- 0666 0667 QCommandLineParser* parseOptions(const QCoreApplication& app) 0668 { 0669 QCommandLineParser* const parser = new QCommandLineParser(); 0670 parser->addOption(QCommandLineOption(QLatin1String("dataset"), QLatin1String("Data set folder"), QLatin1String("path relative to data folder"))); 0671 parser->addOption(QCommandLineOption(QLatin1String("split"), QLatin1String("Split ratio"), QLatin1String("split ratio of training data"))); 0672 parser->addHelpOption(); 0673 parser->process(app); 0674 0675 return parser; 0676 } 0677 0678 int main(int argc, char** argv) 0679 { 0680 QCoreApplication app(argc, argv); 0681 0682 app.setApplicationName(QString::fromLatin1("digikam")); // for DB init. 0683 0684 Benchmark benchmark; 0685 benchmark.m_parser = parseOptions(app); 0686 /* 0687 benchmark.saveData(); 0688 benchmark.testWriteDb(); 0689 benchmark.verifyKNearestDb(); 0690 */ 0691 benchmark.fetchData(); 0692 benchmark.registerTrainingSet(); 0693 benchmark.verifyTestSet(); 0694 0695 return 0; 0696 } 0697 0698 #include "benchmark_recognition_cli.moc"