File indexing completed on 2025-03-09 03:56:55

0001 /* ============================================================
0002  *
0003  * This file is a part of digiKam project
0004  * https://www.digikam.org
0005  *
0006  * Date        : 2020-05-20
0007  * Description : Testing tool for dnn face recognition of face engines
0008  *
0009  * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com>
0010  *
0011  * SPDX-License-Identifier: GPL-2.0-or-later
0012  *
0013  * ============================================================ */
0014 
// C++ includes

#include <algorithm>

// Qt includes

#include <QCoreApplication>
#include <QCommandLineParser>
#include <QDir>
#include <QFile>
#include <QImage>
#include <QElapsedTimer>
#include <QHash>
#include <QJsonArray>
#include <QJsonObject>
#include <QJsonDocument>
#include <QRandomGenerator>

// Local includes

#include "digikam_debug.h"
#include "opencvdnnfacedetector.h"
#include "facedetector.h"
#include "dnnfaceextractor.h"
#include "opencvdnnfacerecognizer.h"
#include "facialrecognition_wrapper.h"
#include "dbengineparameters.h"
#include "coredbaccess.h"
0037 
0038 using namespace Digikam;
0039 
/**
 * @brief The Benchmark class: a benchmark especially made for testing the performance of face recognition module
 * on faces images. The dataset should contain only face data. It's not built for other purposes.
 *
 * Flow used by main(): fetchData() loads the dataset and splits it into a training set and a test set,
 * registerTrainingSet() adds one identity per training label and trains the recognizer with its faces,
 * and verifyTestSet() recognizes the test faces and reports the error rate.
 */
class Benchmark : public QObject
{
    Q_OBJECT

public:

    Benchmark(QObject* const parent = nullptr);
    ~Benchmark() override;

public:

    /// Recognize every face of m_testSet, log accuracy and timing, and store the error rate in m_error.
    void verifyTestSet();

    /// Fill m_trainSet / m_testSet from the label subdirectories of dataDir;
    /// splitRatio is the fraction of each label's files used for training.
    void splitData(const QDir& dataDir, float splitRatio);

public:

    QCommandLineParser* m_parser;       ///< Command line options, set by main() and deleted in the destructor.
    float               m_error;        ///< Error rate of the last verifyTestSet() run; -1 until one has been computed.
    int                 m_trainSize;    ///< Number of faces registered by registerTrainingSet().
    int                 m_testSize;     ///< Number of faces evaluated by verifyTestSet().

private:

    /// Return a newly allocated crop of the first face detected in image,
    /// or nullptr for a null image or when no face is found.
    QImage* detect(const QImage& image)                 const;

    /// Batch version of detect() running the detector through cv::parallel_for_():
    /// one entry per input image, nullptr where no face was found.
    QList<QImage*> detect(const QList<QImage*>& images) const;

    /// Crop the first face detected in faceImg into a 3-channel cv::Mat; false when no face is found.
    bool preprocess(QImage* faceImg, cv::Mat& face)     const;

public Q_SLOTS:

    void fetchData();               ///< Read the "dataset" / "split" options, then call splitData().
    void registerTrainingSet();     ///< Add one identity per training label and train the recognizer with its faces.
    void saveData();                ///< Write face embeddings as a JSON array to "<dataset dir name>.json" (computation currently commented out).
    void testWriteDb();             ///< Decode the embeddings of that JSON file and log timing (DB insertion currently commented out).
    void verifyKNearestDb();        ///< k-nearest neighbors voting accuracy on that JSON file (neighbor query currently commented out).

private:

    // label -> cropped faces allocated by detect().
    // NOTE(review): these QImage objects are never freed in this file; confirm whether
    // FacialRecognitionWrapper::train() / recognizeFaces() take ownership before adding a cleanup.
    QHash<QString, QList<QImage*> > m_trainSet;
    QHash<QString, QList<QImage*> > m_testSet;

    FaceDetector*                   m_detector;     ///< Created in the constructor, deleted in the destructor.
    FacialRecognitionWrapper*       m_recognizer;   ///< Created in the constructor, deleted in the destructor.
};
0089 
0090 // --------------------------------------------------------
0091 
0092 // NOTE: dnn face detector can be parallelized but it takes longer than a single thread
0093 
0094 class ParallelDetector: public cv::ParallelLoopBody
0095 {
0096 public:
0097 
0098     ParallelDetector(FaceDetector* const detector,
0099                      const QList<QImage*>& images,
0100                      QVector<QList<QRectF> >& rects)
0101         : m_detector(detector),
0102           m_images  (images),
0103           m_rects   (rects)
0104     {
0105         m_rects.resize(images.size());
0106     }
0107 
0108     void operator()(const cv::Range& range) const override
0109     {
0110         for (int i = range.start ; i < range.end ; ++i)
0111         {
0112             if (m_images[i]->isNull())
0113             {
0114                 m_rects[i] = QList<QRectF>();
0115             }
0116             else
0117             {
0118                 m_rects[i] = m_detector->detectFaces(*m_images[i]);
0119             }
0120         }
0121     }
0122 
0123 private:
0124 
0125     FaceDetector*            m_detector;
0126     const QList<QImage*>&    m_images;
0127     QVector<QList<QRectF> >& m_rects;
0128 
0129 private:
0130 
0131     Q_DISABLE_COPY(ParallelDetector)
0132 };
0133 
0134 // --------------------------------------------------------
0135 
0136 Benchmark::Benchmark(QObject* const parent)
0137     : QObject    (parent),
0138       m_parser   (nullptr),
0139       m_error    (-1),
0140       m_trainSize(0),
0141       m_testSize (0)
0142 {
0143     DbEngineParameters prm = DbEngineParameters::parametersFromConfig();
0144     CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication);
0145 
0146     m_detector             = new FaceDetector();
0147     m_recognizer           = new FacialRecognitionWrapper();
0148 
0149     m_recognizer->clearAllTraining(QLatin1String("train face classifier"));
0150     m_recognizer->deleteIdentities(m_recognizer->allIdentities());
0151 }
0152 
0153 Benchmark::~Benchmark()
0154 {
0155     delete m_detector;
0156     delete m_recognizer;
0157     delete m_parser;
0158 }
0159 
0160 void Benchmark::registerTrainingSet()
0161 {
0162     m_trainSize = 0;
0163 
0164     QElapsedTimer timer;
0165     timer.start();
0166 /*
0167     QMap<QString, QString> attributes;
0168     attributes[QLatin1String("fullName")] = m_trainSet.begin().key();
0169 
0170     Identity newIdentity = m_recognizer->addIdentity(attributes);
0171 
0172     qCDebug(DIGIKAM_TESTS_LOG) << "add new identity to database" << newIdentity.id();
0173 
0174     m_recognizer->train(newIdentity, m_trainSet.begin().value(), QLatin1String("train face classifier"));
0175 
0176     m_trainSize += m_trainSet.begin().value().size();
0177 */
0178 
0179     for (QHash<QString, QList<QImage*> >::iterator iter  = m_trainSet.begin();
0180                                                    iter != m_trainSet.end();
0181                                                    ++iter)
0182     {
0183         QMultiMap<QString, QString> attributes;
0184         attributes.insert(QLatin1String("fullName"), iter.key());
0185 
0186         Identity newIdentity = m_recognizer->addIdentity(attributes);
0187 
0188         qCDebug(DIGIKAM_TESTS_LOG) << "add new identity to database" << newIdentity.id();
0189 
0190         m_recognizer->train(newIdentity, iter.value(), QLatin1String("train face classifier"));
0191 
0192         m_trainSize += iter.value().size();
0193     }
0194 
0195     unsigned int elapsedDetection = timer.elapsed();
0196 
0197     qCDebug(DIGIKAM_TESTS_LOG) << "Registered <<  :" << m_trainSize
0198              << "faces in training set";
0199 
0200     if (m_trainSize != 0)
0201     {
0202         qCDebug(DIGIKAM_TESTS_LOG) << "with average" << float(elapsedDetection) / m_trainSize << "ms / face";
0203     }
0204 }
0205 
0206 void Benchmark::verifyTestSet()
0207 {
0208     int nbNotRecognize = 0;
0209     int nbWrongLabel   = 0;
0210     m_testSize         = 0;
0211 
0212     QElapsedTimer timer;
0213     timer.start();
0214 
0215     for (QHash<QString, QList<QImage*> >::iterator iter  = m_testSet.begin();
0216                                                    iter != m_testSet.end();
0217                                                    ++iter)
0218     {
0219         QList<Identity> predictions = m_recognizer->recognizeFaces(iter.value());
0220 
0221         for (int i = 0 ; i < predictions.size() ; ++i)
0222         {
0223 /*
0224             Identity prediction = m_recognizer->identity(ids[i]);
0225 */
0226             if      (predictions[i].isNull())
0227             {
0228                 if (m_trainSet.contains(iter.key()))
0229                 {
0230                     // cannot recognize when label is already register
0231 
0232                     ++nbNotRecognize;
0233                 }
0234             }
0235             else if (predictions[i].attribute(QLatin1String("fullName")) != iter.key())
0236             {
0237                 // wrong label
0238 
0239                 ++nbWrongLabel;
0240             }
0241         }
0242 
0243         m_testSize += iter.value().size();
0244     }
0245 
0246     unsigned int elapsedDetection = timer.elapsed();
0247 
0248     if (m_testSize == 0)
0249     {
0250         qCWarning(DIGIKAM_TESTS_LOG) << "test set is empty";
0251 
0252         return;
0253     }
0254 
0255     m_error = float(nbNotRecognize + nbWrongLabel)/m_testSize;
0256 
0257     qCDebug(DIGIKAM_TESTS_LOG) << "nb Not Recognized :" << nbNotRecognize;
0258     qCDebug(DIGIKAM_TESTS_LOG) << "nb Wrong Label :"    << nbWrongLabel;
0259 
0260     qCDebug(DIGIKAM_TESTS_LOG) << "Accuracy :" << (1 - m_error) * 100 << "%"
0261              << "on total"   << m_trainSize << "training faces, and"
0262                              << m_testSize << "test faces";
0263     if (m_testSize != 0)
0264     {
0265         qCDebug(DIGIKAM_TESTS_LOG) << "(" << float(elapsedDetection) / m_testSize << "ms/face )";
0266     }
0267 }
0268 
0269 QImage* Benchmark::detect(const QImage& faceImg) const
0270 {
0271     if (faceImg.isNull())
0272     {
0273         return nullptr;
0274     }
0275 
0276     QList<QRectF> faces = m_detector->detectFaces(faceImg);
0277 
0278     if (faces.isEmpty())
0279     {
0280         return nullptr;
0281     }
0282 
0283     QRect rect                = FaceDetector::toAbsoluteRect(faces[0], faceImg.size());
0284 
0285     QImage* const croppedFace = new QImage();
0286     *croppedFace              = faceImg.copy(rect);
0287 
0288     return croppedFace;
0289 }
0290 
0291 QList<QImage*> Benchmark::detect(const QList<QImage*>& images) const
0292 {
0293     QVector<QList<QRectF> > faces;
0294 
0295     cv::parallel_for_(cv::Range(0, images.size()), ParallelDetector(m_detector, images, faces));
0296 
0297     QList<QImage*> croppedFaces;
0298 
0299     for (int i = 0 ; i < faces.size() ; ++i)
0300     {
0301         if (faces[i].isEmpty())
0302         {
0303             croppedFaces << nullptr;
0304         }
0305         else
0306         {
0307             QRect rect                = FaceDetector::toAbsoluteRect(faces[i][0], images[i]->size());
0308             QImage* const croppedFace = new QImage();
0309             *croppedFace              = images[i]->copy(rect);
0310 
0311             croppedFaces << croppedFace;
0312         }
0313     }
0314 
0315     return croppedFaces;
0316 }
0317 
0318 bool Benchmark::preprocess(QImage* faceImg, cv::Mat& face) const
0319 {
0320     QList<QRectF> faces = m_detector->detectFaces(*faceImg);
0321 
0322     if (faces.isEmpty())
0323     {
0324         return false;
0325     }
0326 
0327     QRect rect          = FaceDetector::toAbsoluteRect(faces[0], faceImg->size());
0328     QImage croppedFace  = faceImg->copy(rect);
0329     cv::Mat cvImageWrapper;
0330 
0331     switch (croppedFace.format())
0332     {
0333         case QImage::Format_RGB32:
0334         case QImage::Format_ARGB32:
0335         case QImage::Format_ARGB32_Premultiplied:
0336 
0337             // I think we can ignore premultiplication when converting to grayscale
0338 
0339             cvImageWrapper = cv::Mat(croppedFace.height(), croppedFace.width(), CV_8UC4, croppedFace.scanLine(0), croppedFace.bytesPerLine());
0340             cv::cvtColor(cvImageWrapper, face, CV_RGBA2RGB);
0341 
0342             break;
0343 
0344         default:
0345 
0346             croppedFace = croppedFace.convertToFormat(QImage::Format_RGB888);
0347             face        = cv::Mat(croppedFace.height(), croppedFace.width(), CV_8UC3, croppedFace.scanLine(0), croppedFace.bytesPerLine());
0348 /*
0349             cvtColor(cvImageWrapper, cvImage, CV_RGB2GRAY);
0350 */
0351             break;
0352     }
0353 
0354     return true;
0355 }
0356 
0357 void Benchmark::splitData(const QDir& dataDir, float splitRatio)
0358 {
0359     int nbData = 0;
0360 
0361     // Each subdirectory in data directory should match with a label
0362 
0363     QFileInfoList subDirs = dataDir.entryInfoList(QDir::NoDotAndDotDot | QDir::Dirs, QDir::Name);
0364 
0365     QElapsedTimer timer;
0366     timer.start();
0367 
0368     for (int i = 0 ; i < subDirs.size() ; ++i)
0369     {
0370         QDir subDir(subDirs[i].absoluteFilePath());
0371 
0372         QString label = subDirs[i].fileName();
0373 
0374         QFileInfoList filesInfo = subDir.entryInfoList(QDir::Files | QDir::Readable);
0375 
0376         // suffle dataset
0377 
0378         QList<QFileInfo>::iterator it = filesInfo.begin();
0379         QList<QFileInfo>::iterator it1;
0380 
0381         for (int j = 0 ; j < filesInfo.size() ; ++j)
0382         {
0383             int inc = QRandomGenerator::global()->bounded(filesInfo.size());
0384 
0385             it1     = filesInfo.begin();
0386             it1    += inc;
0387 
0388             std::swap(*(it++), *(it1));
0389          }
0390 
0391         //QString faceDir = QLatin1String("./cropped_face/");
0392 /*
0393         QList<QImage*> images;
0394 
0395         for (int i = 0 ; i < filesInfo.size() ; ++i)
0396         {
0397             images << new QImage(filesInfo[i].absoluteFilePath());
0398         }
0399 
0400         QList<QImage*> croppedFaces = detect(images);
0401 
0402         // split train/test
0403 
0404         for (int i = 0 ; i < croppedFaces.size() ; ++i)
0405         {
0406             if (croppedFaces[i])
0407             {
0408                 croppedFaces[i]->save(faceDir + label + QLatin1String("_") + QString::number(i) + QLatin1String(".png"), "PNG");
0409             }
0410 
0411             if (i < (filesInfo.size() * splitRatio))
0412             {
0413                 if (croppedFaces[i] && !croppedFaces[i]->isNull())
0414                 {
0415                     m_trainSet[label].append(croppedFaces[i]);
0416                     ++nbData;
0417                 }
0418             }
0419             else
0420             {
0421                 if (croppedFaces[i] && !croppedFaces[i]->isNull())
0422                 {
0423                     m_testSet[label].append(croppedFaces[i]);
0424                     ++nbData;
0425                 }
0426             }
0427         }
0428 */
0429         for (int j = 0 ; j < filesInfo.size() ; ++j)
0430         {
0431             QImage img(filesInfo[j].absoluteFilePath());
0432 
0433             QImage* const croppedFace = detect(img);
0434 
0435             // Save cropped face for detection testing
0436 
0437             if (croppedFace)
0438             {
0439 /*
0440                 croppedFace->save(faceDir + label + QLatin1String("_") + QString::number(i) + QLatin1String(".png"), "PNG");
0441 */
0442             }
0443 
0444             if (j < (filesInfo.size() * splitRatio))
0445             {
0446                 if (croppedFace && !croppedFace->isNull())
0447                 {
0448                     m_trainSet[label].append(croppedFace);
0449                     ++nbData;
0450                 }
0451             }
0452             else
0453             {
0454                 if (croppedFace && !croppedFace->isNull())
0455                 {
0456                     m_testSet[label].append(croppedFace);
0457                     ++nbData;
0458                 }
0459             }
0460         }
0461     }
0462 
0463     unsigned int elapsedDetection = timer.elapsed();
0464 
0465     qCDebug(DIGIKAM_TESTS_LOG) << "Fetched dataset with" << nbData << "samples";
0466 
0467     if (nbData != 0)
0468     {
0469         qCDebug(DIGIKAM_TESTS_LOG) << "with average" << float(elapsedDetection) / nbData << "ms/image.";
0470     }
0471 }
0472 
0473 void Benchmark::fetchData()
0474 {
0475     if (! m_parser->isSet(QLatin1String("dataset")))
0476     {
0477         qWarning("Data set is not set !!!");
0478 
0479         return;
0480     }
0481 
0482     QDir dataset(m_parser->value(QLatin1String("dataset")));
0483 
0484     float splitRatio = 0.8;
0485 
0486     if (m_parser->isSet(QLatin1String("split")))
0487     {
0488         splitRatio = m_parser->value(QLatin1String("split")).toFloat();
0489     }
0490 
0491     splitData(dataset, splitRatio);
0492 }
0493 
0494 void Benchmark::saveData()
0495 {
0496     QDir dataDir(m_parser->value(QLatin1String("dataset")));
0497     QFileInfoList subDirs = dataDir.entryInfoList(QDir::NoDotAndDotDot | QDir::Dirs, QDir::Name);
0498 
0499     QElapsedTimer timer;
0500     timer.start();
0501 
0502     QJsonArray faceEmbeddingArray;
0503 
0504     for (int i = 0 ; i < subDirs.size() ; ++i)
0505     {
0506         QDir subDir(subDirs[i].absoluteFilePath());
0507         QFileInfoList filesInfo = subDir.entryInfoList(QDir::Files | QDir::Readable);
0508 
0509         for (int j = 0 ; j < filesInfo.size() ; ++j)
0510         {
0511             QImage* const img = new QImage(filesInfo[j].absoluteFilePath());
0512 
0513             if (! img->isNull())
0514             {
0515                 cv::Mat face;
0516 /*
0517                 if (preprocess(img, face))
0518                 {
0519 
0520                     Identity newIdentity = m_recognizer->newIdentity(face);
0521 
0522                     QJsonObject identityJson;
0523                     identityJson[QLatin1String("id")] = i;
0524                     identityJson[QLatin1String("faceembedding")] = QJsonDocument::fromJson(newIdentity.attribute(QLatin1String("faceEmbedding")).toLatin1()).array();
0525 
0526                     faceEmbeddingArray.append(identityJson);
0527                 }
0528 */
0529             }
0530         }
0531     }
0532 
0533     qCDebug(DIGIKAM_TESTS_LOG) << "Save face embedding in" << timer.elapsed() << "ms/face";
0534 
0535     QFile dataFile(dataDir.dirName() + QLatin1String(".json"));
0536 
0537     if (!dataFile.open(QIODevice::WriteOnly))
0538     {
0539         qWarning("Couldn't open save file.");
0540         return;
0541     }
0542 
0543     QJsonDocument saveDoc(faceEmbeddingArray);
0544     dataFile.write(saveDoc.toJson());
0545 
0546     dataFile.close();
0547 }
0548 
0549 void Benchmark::testWriteDb()
0550 {
0551     QDir dataDir(m_parser->value(QLatin1String("dataset")));
0552 
0553     QFile dataFile(dataDir.dirName() + QLatin1String(".json"));
0554 
0555     if (!dataFile.open(QIODevice::ReadOnly))
0556     {
0557         qWarning("Couldn't open data file.");
0558         return;
0559     }
0560 
0561     QByteArray saveData = dataFile.readAll();
0562     dataFile.close();
0563 
0564     QJsonDocument loadDoc(QJsonDocument::fromJson(saveData));
0565 
0566     QJsonArray data     = loadDoc.array();
0567 
0568     QElapsedTimer timer;
0569     timer.start();
0570 
0571     for (int i = 0 ; i < data.size() ; ++i)
0572     {
0573         QJsonObject object               = data[i].toObject();
0574         std::vector<float> faceEmbedding = DNNFaceExtractor::decodeVector(object[QLatin1String("faceembedding")].toArray());
0575 /*
0576         m_recognizer->insertData(DNNFaceExtractor::vectortomat(faceEmbedding), object[QLatin1String("id")].toInt());
0577 */
0578     }
0579 
0580     qCDebug(DIGIKAM_TESTS_LOG) << "write face embedding to spatial database with average" << timer.elapsed() /data.size() << "ms/faceEmbedding";
0581 }
0582 
0583 void Benchmark::verifyKNearestDb()
0584 {
0585     QDir dataDir(m_parser->value(QLatin1String("dataset")));
0586 
0587     QFile dataFile(dataDir.dirName() + QLatin1String(".json"));
0588 
0589     if (!dataFile.open(QIODevice::ReadOnly))
0590     {
0591         qWarning("Couldn't open data file.");
0592         return;
0593     }
0594 
0595     QByteArray saveData = dataFile.readAll();
0596     dataFile.close();
0597 
0598     QJsonDocument loadDoc(QJsonDocument::fromJson(saveData));
0599 
0600     QJsonArray data = loadDoc.array();
0601     int nbCorrect   = 0;
0602 
0603     QElapsedTimer timer;
0604     timer.start();
0605 
0606     for (int i = 0 ; i < data.size() ; ++i)
0607     {
0608         QJsonObject object               = data[i].toObject();
0609         std::vector<float> faceEmbedding = DNNFaceExtractor::decodeVector(object[QLatin1String("faceembedding")].toArray());
0610         int label                        = object[QLatin1String("id")].toInt();
0611 
0612         QMap<double, QVector<int> > closestNeighbors
0613 /*
0614             = m_recognizer->getClosestNodes(DNNFaceExtractor::vectortomat(faceEmbedding), 1.0, 5)
0615 */
0616             ;
0617 
0618         QMap<int, QVector<double> > votingGroups;
0619 
0620         for (QMap<double, QVector<int> >::const_iterator iter  = closestNeighbors.cbegin();
0621                                                          iter != closestNeighbors.cend();
0622                                                          ++iter)
0623         {
0624             for (int j = 0 ; j < iter.value().size() ; ++j)
0625             {
0626                 votingGroups[iter.value()[j]].append(iter.key());
0627             }
0628         }
0629 
0630         double maxScore = 0;
0631         int prediction  = -1;
0632 
0633         for (QMap<int, QVector<double> >::const_iterator group  = votingGroups.cbegin();
0634                                                          group != votingGroups.cend();
0635                                                          ++group)
0636         {
0637             double score = 0;
0638 
0639             for (int j = 0 ; j < group.value().size() ; ++j)
0640             {
0641                 score += (1 - group.value()[j]);
0642             }
0643 
0644             if (score > maxScore)
0645             {
0646                 maxScore   = score;
0647                 prediction = group.key();
0648             }
0649         }
0650 
0651         if (label == prediction)
0652         {
0653             ++nbCorrect;
0654         }
0655     }
0656 
0657     if (data.size() != 0)
0658     {
0659         qCDebug(DIGIKAM_TESTS_LOG) << "Accuracy"     << (float(nbCorrect) / data.size())*100
0660                  << "with average" << timer.elapsed()   / data.size()
0661                  << "ms/faceEmbedding";
0662     }
0663 }
0664 
0665 // --------------------------------------------------------
0666 
0667 QCommandLineParser* parseOptions(const QCoreApplication& app)
0668 {
0669     QCommandLineParser* const parser = new QCommandLineParser();
0670     parser->addOption(QCommandLineOption(QLatin1String("dataset"), QLatin1String("Data set folder"), QLatin1String("path relative to data folder")));
0671     parser->addOption(QCommandLineOption(QLatin1String("split"), QLatin1String("Split ratio"), QLatin1String("split ratio of training data")));
0672     parser->addHelpOption();
0673     parser->process(app);
0674 
0675     return parser;
0676 }
0677 
0678 int main(int argc, char** argv)
0679 {
0680     QCoreApplication app(argc, argv);
0681 
0682     app.setApplicationName(QString::fromLatin1("digikam"));          // for DB init.
0683 
0684     Benchmark benchmark;
0685     benchmark.m_parser = parseOptions(app);
0686 /*
0687     benchmark.saveData();
0688     benchmark.testWriteDb();
0689     benchmark.verifyKNearestDb();
0690 */
0691     benchmark.fetchData();
0692     benchmark.registerTrainingSet();
0693     benchmark.verifyTestSet();
0694 
0695     return 0;
0696 }
0697 
0698 #include "benchmark_recognition_cli.moc"