File indexing completed on 2025-03-09 03:56:56

0001 /* ============================================================
0002  *
0003  * This file is a part of digiKam project
0004  * https://www.digikam.org
0005  *
0006  * Date        : 2020-05-20
0007  * Description : Testing tool for dnn face recognition of face engines
0008  *
0009  * SPDX-FileCopyrightText: 2020 by Nghia Duong <minhnghiaduong997 at gmail dot com>
0010  *
0011  * SPDX-License-Identifier: GPL-2.0-or-later
0012  *
0013  * ============================================================ */
0014 
0015 // Qt includes
0016 
0017 #include <QApplication>
0018 #include <QCommandLineParser>
0019 #include <QMainWindow>
0020 #include <QScrollArea>
0021 #include <QHBoxLayout>
0022 #include <QVBoxLayout>
0023 #include <QFormLayout>
0024 #include <QListWidget>
0025 #include <QListWidgetItem>
0026 #include <QDir>
0027 #include <QImage>
0028 #include <QElapsedTimer>
0029 #include <QLabel>
0030 #include <QPen>
0031 #include <QPainter>
0032 #include <QLineEdit>
0033 #include <QPushButton>
0034 #include <QHash>
0035 #include <QRandomGenerator>
0036 
0037 // Local includes
0038 
0039 #include "digikam_debug.h"
0040 #include "opencvdnnfacedetector.h"
0041 #include "facedetector.h"
0042 #include "dnnfaceextractor.h"
0043 #include "opencvdnnfacerecognizer.h"
0044 #include "identity.h"
0045 #include "coredbaccess.h"
0046 #include "dbengineparameters.h"
0047 #include "facialrecognition_wrapper.h"
0048 
0049 using namespace Digikam;
0050 
0051 // TODO: Recognition is incorrect when people are wearing glasses
0052 
0053 static QVector<QListWidgetItem*> splitData(const QDir& dataDir, float splitRatio ,
0054                                            QHash<QString, QVector<QImage> >& trainSet,
0055                                            QHash<QString, QVector<QImage> >& testSet)
0056 {
0057 
0058     QVector<QListWidgetItem*> imageItems;
0059 
0060     // Each subdirectory in data directory should match with a label
0061 
0062     QFileInfoList subDirs = dataDir.entryInfoList(QDir::NoDotAndDotDot | QDir::Dirs, QDir::Name);
0063 
0064     for (int i = 0 ; i < subDirs.size() ; ++i)
0065     {
0066         QDir subDir(subDirs[i].absoluteFilePath());
0067 
0068         QString label           = subDirs[i].fileName();
0069         QFileInfoList filesInfo = subDir.entryInfoList(QDir::Files | QDir::Readable);
0070 
0071         // suffle dataset
0072 
0073         QList<QFileInfo>::iterator it = filesInfo.begin();
0074         QList<QFileInfo>::iterator it1;
0075 
0076         for (int j = 0 ; j < filesInfo.size() ; ++j)
0077         {
0078             int inc = QRandomGenerator::global()->bounded(filesInfo.size());
0079 
0080             it1     = filesInfo.begin();
0081             it1    += inc;
0082 
0083             std::swap(*(it++), *(it1));
0084         }
0085 
0086         // split train/test
0087 
0088         for (int j = 0 ; i < filesInfo.size() ; ++j)
0089         {
0090             QImage img(filesInfo[j].absoluteFilePath());
0091 
0092             if (j < (filesInfo.size() * splitRatio))
0093             {
0094                 if (!img.isNull())
0095                 {
0096                     trainSet[label].append(img);
0097                     imageItems.append(new QListWidgetItem(QIcon(filesInfo[j].absoluteFilePath()), filesInfo[j].absoluteFilePath()));
0098                 }
0099             }
0100             else
0101             {
0102                 if (!img.isNull())
0103                 {
0104                     testSet[label].append(img);
0105                     imageItems.append(new QListWidgetItem(QIcon(filesInfo[j].absoluteFilePath()), filesInfo[j].absoluteFilePath()));
0106                 }
0107             }
0108         }
0109     }
0110 
0111     return imageItems;
0112 }
0113 
0114 
0115 class MainWindow : public QMainWindow
0116 {
0117     Q_OBJECT
0118 
0119 public:
0120 
0121     explicit MainWindow(const QDir& directory, QWidget* const parent = nullptr);
0122     ~MainWindow() override;
0123 
0124 private:
0125 
0126     QPixmap showCVMat(const cv::Mat& cvimage);
0127     QList<QRectF> detectFaces(const QString& imagePath);
0128     void extractFaces(const QImage& img, QImage& imgScaled, const QList<QRectF>& faces);
0129 
0130     QWidget* setupFullImageArea();
0131     QWidget* setupCroppedFaceArea();
0132     QWidget* setupPreprocessedFaceArea();
0133     QWidget* setupAlignedFaceArea();
0134     QWidget* setupControlPanel();
0135     QWidget* setupImageList(const QDir& directory);
0136 
0137 private Q_SLOTS:
0138 
0139     void slotDetectFaces(const QListWidgetItem* imageItem);
0140     void slotSaveIdentity();
0141     void slotIdentify(int index);
0142 
0143 private:
0144 
0145     OpenCVDNNFaceDetector*    m_detector;
0146     FacialRecognitionWrapper* recognitionWrapper;
0147     OpenCVDNNFaceRecognizer*  m_recognizer;
0148     DNNFaceExtractor*         m_extractor;
0149     QVector<cv::Mat>          m_preprocessedFaces;
0150     Identity                  m_currentIdenity;
0151 
0152 
0153     QLabel*                   m_fullImage;
0154     QListWidget*              m_imageListView;
0155     QListWidget*              m_croppedfaceList;
0156     //QVBoxLayout*              m_preprocessedList;
0157     QListWidget*              m_preprocessedList;
0158     QListWidget*              m_alignedList;
0159 
0160     // control panel
0161     QLineEdit*                m_imageLabel;
0162     QLabel*                   m_similarityLabel;
0163     QLabel*                   m_recognizationInfo;
0164     QPushButton*              m_applyButton;
0165 };
0166 
0167 MainWindow::MainWindow(const QDir &directory, QWidget* const parent)
0168     : QMainWindow(parent)
0169 {
0170     DbEngineParameters prm = DbEngineParameters::parametersFromConfig();
0171     CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication);
0172 
0173     setWindowTitle(QLatin1String("Face recognition Test"));
0174 
0175     recognitionWrapper     = new FacialRecognitionWrapper();
0176 
0177     m_detector             = new OpenCVDNNFaceDetector(DetectorNNModel::YOLO);
0178     m_recognizer           = new OpenCVDNNFaceRecognizer(OpenCVDNNFaceRecognizer::Tree);
0179     m_extractor            = new DNNFaceExtractor();
0180 
0181     // Image erea
0182 
0183     QWidget*     const imageArea        = new QWidget(this);
0184     QHBoxLayout*       processingLayout = new QHBoxLayout(imageArea);
0185 
0186     processingLayout->addWidget(setupFullImageArea());
0187     processingLayout->addWidget(setupCroppedFaceArea());
0188     processingLayout->addWidget(setupPreprocessedFaceArea());
0189     processingLayout->addWidget(setupAlignedFaceArea());
0190     processingLayout->addWidget(setupControlPanel());
0191 
0192     QSizePolicy spHigh(QSizePolicy::Preferred, QSizePolicy::Preferred);
0193     spHigh.setVerticalPolicy(QSizePolicy::Expanding);
0194     imageArea->setSizePolicy(spHigh);
0195 
0196     QWidget*     const mainWidget = new QWidget(this);
0197     QVBoxLayout* const mainLayout = new QVBoxLayout(mainWidget);
0198     mainLayout->addWidget(imageArea);
0199     mainLayout->addWidget(setupImageList(directory));
0200 
0201     setCentralWidget(mainWidget);
0202 }
0203 
0204 MainWindow::~MainWindow()
0205 {
0206     delete recognitionWrapper;
0207     delete m_detector;
0208     delete m_recognizer;
0209     delete m_extractor;
0210 
0211     delete m_fullImage;
0212     delete m_imageListView;
0213     delete m_croppedfaceList;
0214     delete m_preprocessedList;
0215     delete m_alignedList;
0216 
0217     delete m_recognizationInfo;
0218     delete m_imageLabel;
0219     delete m_similarityLabel;
0220     delete m_applyButton;
0221 }
0222 
0223 void MainWindow::slotDetectFaces(const QListWidgetItem* imageItem)
0224 {
0225     QString imagePath = imageItem->text();
0226     qCDebug(DIGIKAM_TESTS_LOG) << "Loading " << imagePath;
0227     QImage img(imagePath);
0228     QImage imgScaled(img.scaled(416, 416, Qt::KeepAspectRatio));
0229 
0230     m_imageLabel->clear();
0231     m_recognizationInfo->clear();
0232 
0233     disconnect(m_alignedList, &QListWidget::currentRowChanged,
0234                this, &MainWindow::slotIdentify);
0235 
0236     // clear faces layout
0237 
0238     QListWidgetItem* wItem = nullptr;
0239 
0240     while ((wItem = m_croppedfaceList->item(0)) != nullptr)
0241     {
0242         delete wItem;
0243     }
0244 
0245     while ((wItem = m_preprocessedList->item(0)) != nullptr)
0246     {
0247         delete wItem;
0248     }
0249 
0250     while ((wItem = m_alignedList->item(0)) != nullptr)
0251     {
0252         delete wItem;
0253     }
0254 
0255     QList<QRectF> faces = detectFaces(imagePath);
0256 
0257     extractFaces(img, imgScaled, faces);
0258 
0259     // Only setPixmap after finishing drawing bboxes around detected faces
0260     m_fullImage->setPixmap(QPixmap::fromImage(imgScaled));
0261 
0262     connect(m_alignedList, &QListWidget::currentRowChanged,
0263             this, &MainWindow::slotIdentify);
0264 }
0265 
0266 QPixmap MainWindow::showCVMat(const cv::Mat& cvimage)
0267 {
0268     if (cvimage.cols*cvimage.rows != 0)
0269     {
0270         cv::Mat rgb;
0271         QPixmap p;
0272         cv::cvtColor(cvimage, rgb, (-2*cvimage.channels()+10));
0273         p.convertFromImage(QImage(rgb.data, rgb.cols, rgb.rows, QImage::Format_RGB888));
0274 
0275         return p;
0276     }
0277 
0278     return QPixmap();
0279 }
0280 
0281 QList<QRectF> MainWindow::detectFaces(const QString& imagePath)
0282 {
0283     QImage img(imagePath);
0284 
0285     QList<QRectF> faces;
0286 
0287     try
0288     {
0289         QElapsedTimer timer;
0290         unsigned int elapsedDetection = 0;
0291         timer.start();
0292 
0293         // NOTE detection with filePath won't work when format is not standard
0294         // NOTE unexpected behaviour with detecFaces(const QString&)
0295         cv::Size paddedSize(0, 0);
0296         cv::Mat cvImage       = m_detector->prepareForDetection(img, paddedSize);
0297         QList<QRect> absRects = m_detector->detectFaces(cvImage, paddedSize);
0298         faces                 = FaceDetector::toRelativeRects(absRects,
0299                                                               QSize(cvImage.cols - 2*paddedSize.width,
0300                                                               cvImage.rows - 2*paddedSize.height));
0301         elapsedDetection = timer.elapsed();
0302 
0303         qCDebug(DIGIKAM_TESTS_LOG) << "(Input CV) Found " << absRects.size() << " faces, in " << elapsedDetection << "ms";
0304     }
0305     catch (cv::Exception& e)
0306     {
0307         qCWarning(DIGIKAM_TESTS_LOG) << "cv::Exception:" << e.what();
0308     }
0309     catch (...)
0310     {
0311         qCWarning(DIGIKAM_TESTS_LOG) << "Default exception from OpenCV";
0312     }
0313 
0314     return faces;
0315 }
0316 
0317 void MainWindow::extractFaces(const QImage& img, QImage& imgScaled, const QList<QRectF>& faces)
0318 {
0319     if (faces.isEmpty())
0320     {
0321         qCWarning(DIGIKAM_TESTS_LOG) << "No face detected";
0322         return;
0323     }
0324 
0325     qCDebug(DIGIKAM_TESTS_LOG) << "Coordinates of detected faces : ";
0326 
0327     Q_FOREACH (const QRectF& r, faces)
0328     {
0329         qCDebug(DIGIKAM_TESTS_LOG) << r;
0330     }
0331 
0332     QPainter painter(&imgScaled);
0333     QPen paintPen(Qt::green);
0334     paintPen.setWidth(1);
0335     painter.setPen(paintPen);
0336 
0337     m_preprocessedFaces.clear();
0338 
0339     Q_FOREACH (const QRectF& rr, faces)
0340     {
0341         QRect rectDraw      = FaceDetector::toAbsoluteRect(rr, imgScaled.size());
0342         QRect rect          = FaceDetector::toAbsoluteRect(rr, img.size());
0343         QImage part         = img.copy(rect);
0344 
0345         // Show cropped faces
0346 
0347         QIcon croppedFace(QPixmap::fromImage(part.scaled(qMin(img.size().width(), 100),
0348                                                          qMin(img.size().width(), 100),
0349                                                          Qt::KeepAspectRatio)));
0350 
0351         m_croppedfaceList->addItem(new QListWidgetItem(croppedFace, QLatin1String("")));
0352         painter.drawRect(rectDraw);
0353 
0354         // Show preprocessed faces
0355 
0356         cv::Mat cvPreprocessedFace = m_recognizer->prepareForRecognition(part);
0357         m_preprocessedList->addItem(new QListWidgetItem(QIcon(showCVMat(cvPreprocessedFace)), QLatin1String("")));
0358 
0359         m_preprocessedFaces << cvPreprocessedFace;
0360 
0361         // Show aligned faces
0362 
0363         cv::Mat cvAlignedFace = m_extractor->alignFace(cvPreprocessedFace);
0364         m_alignedList->addItem(new QListWidgetItem(QIcon(showCVMat(cvAlignedFace)), QLatin1String("")));
0365     }
0366 }
0367 
0368 QWidget* MainWindow::setupFullImageArea()
0369 {
0370     m_fullImage = new QLabel;
0371     m_fullImage->setScaledContents(true);
0372 
0373     QSizePolicy spImage(QSizePolicy::Preferred, QSizePolicy::Expanding);
0374     spImage.setHorizontalStretch(3);
0375     m_fullImage->setSizePolicy(spImage);
0376 
0377     return m_fullImage;
0378 }
0379 
0380 QWidget* MainWindow::setupCroppedFaceArea()
0381 {
0382     QScrollArea* const facesArea = new QScrollArea(this);
0383     facesArea->setWidgetResizable(true);
0384     facesArea->setAlignment(Qt::AlignRight);
0385     facesArea->setVerticalScrollBarPolicy(Qt::ScrollBarAlwaysOn);
0386     facesArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
0387 
0388     QSizePolicy spImage(QSizePolicy::Preferred, QSizePolicy::Expanding);
0389     spImage.setHorizontalStretch(1);
0390 
0391     facesArea->setSizePolicy(spImage);
0392 
0393     m_croppedfaceList = new QListWidget(this);
0394 
0395     m_croppedfaceList->setViewMode(QListView::IconMode);
0396     m_croppedfaceList->setIconSize(QSize(200, 150));
0397     m_croppedfaceList->setResizeMode(QListWidget::Adjust);
0398     m_croppedfaceList->setFlow(QListView::TopToBottom);
0399     m_croppedfaceList->setWrapping(false);
0400     m_croppedfaceList->setDragEnabled(false);
0401 
0402     facesArea->setWidget(m_croppedfaceList);
0403 
0404     return facesArea;
0405 }
0406 
0407 QWidget* MainWindow::setupPreprocessedFaceArea()
0408 {
0409     // preprocessed face area
0410 
0411     QScrollArea* const preprocessedFacesArea = new QScrollArea(this);
0412     preprocessedFacesArea->setWidgetResizable(true);
0413     preprocessedFacesArea->setAlignment(Qt::AlignRight);
0414     preprocessedFacesArea->setVerticalScrollBarPolicy(Qt::ScrollBarAlwaysOn);
0415     preprocessedFacesArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
0416 
0417     QSizePolicy spImage(QSizePolicy::Preferred, QSizePolicy::Expanding);
0418     spImage.setHorizontalStretch(1);
0419 
0420     preprocessedFacesArea->setSizePolicy(spImage);
0421 
0422     m_preprocessedList = new QListWidget();
0423 
0424     m_preprocessedList->setViewMode(QListView::IconMode);
0425     m_preprocessedList->setIconSize(QSize(200, 150));
0426     m_preprocessedList->setResizeMode(QListWidget::Adjust);
0427     m_preprocessedList->setFlow(QListView::TopToBottom);
0428     m_preprocessedList->setWrapping(false);
0429     m_preprocessedList->setDragEnabled(false);
0430 
0431     preprocessedFacesArea->setWidget(m_preprocessedList);
0432 
0433     return preprocessedFacesArea;
0434 }
0435 
0436 QWidget* MainWindow::setupAlignedFaceArea()
0437 {
0438     // aligned face area
0439 
0440     QScrollArea* const alignedFacesArea = new QScrollArea(this);
0441     alignedFacesArea->setWidgetResizable(true);
0442     alignedFacesArea->setAlignment(Qt::AlignRight);
0443     alignedFacesArea->setVerticalScrollBarPolicy(Qt::ScrollBarAlwaysOn);
0444     alignedFacesArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
0445 
0446     QSizePolicy spImage(QSizePolicy::Preferred, QSizePolicy::Expanding);
0447     spImage.setHorizontalStretch(1);
0448 
0449     alignedFacesArea->setSizePolicy(spImage);
0450 
0451     m_alignedList = new QListWidget(this);
0452 
0453     m_alignedList->setViewMode(QListView::IconMode);
0454     m_alignedList->setIconSize(QSize(200, 150));
0455     m_alignedList->setResizeMode(QListWidget::Adjust);
0456     m_alignedList->setFlow(QListView::TopToBottom);
0457     m_alignedList->setWrapping(false);
0458     m_alignedList->setDragEnabled(false);
0459 
0460     alignedFacesArea->setWidget(m_alignedList);
0461 
0462     return alignedFacesArea;
0463 }
0464 
0465 QWidget* MainWindow::setupControlPanel()
0466 {
0467     QWidget* const controlPanel = new QWidget();
0468 
0469     QSizePolicy spImage(QSizePolicy::Preferred, QSizePolicy::Expanding);
0470     spImage.setHorizontalStretch(2);
0471 
0472     controlPanel->setSizePolicy(spImage);
0473 
0474     m_recognizationInfo       = new QLabel(this);
0475     m_imageLabel              = new QLineEdit(this);
0476     m_similarityLabel         = new QLabel(this);
0477     m_applyButton             = new QPushButton(this);
0478     m_applyButton->setText(QLatin1String("Apply"));
0479 
0480     QFormLayout* const layout = new QFormLayout(this);
0481     layout->addRow(new QLabel(QLatin1String("Recognized :")), m_recognizationInfo);
0482     layout->addRow(new QLabel(QLatin1String("Identity :")),  m_imageLabel);
0483     layout->addRow(new QLabel(QLatin1String("Similarity distance :")), m_similarityLabel);
0484     layout->insertRow(3, m_applyButton);
0485 
0486     controlPanel->setLayout(layout);
0487 
0488     connect(m_applyButton, &QPushButton::clicked,
0489             this, &MainWindow::slotSaveIdentity);
0490 
0491     return controlPanel;
0492 }
0493 
0494 QWidget* MainWindow::setupImageList(const QDir& directory)
0495 {
0496     // Itemlist erea
0497 
0498     QScrollArea* const itemsArea = new QScrollArea;
0499     itemsArea->setWidgetResizable(true);
0500     itemsArea->setAlignment(Qt::AlignBottom);
0501     itemsArea->setVerticalScrollBarPolicy(Qt::ScrollBarAlwaysOff);
0502     itemsArea->setHorizontalScrollBarPolicy(Qt::ScrollBarAlwaysOn);
0503 
0504     QSizePolicy spLow(QSizePolicy::Preferred, QSizePolicy::Fixed);
0505     itemsArea->setSizePolicy(spLow);
0506 
0507     m_imageListView = new QListWidget(this);
0508     m_imageListView->setViewMode(QListView::IconMode);
0509     m_imageListView->setIconSize(QSize(200, 150));
0510     m_imageListView->setResizeMode(QListWidget::Adjust);
0511     m_imageListView->setFlow(QListView::LeftToRight);
0512     m_imageListView->setWrapping(false);
0513     m_imageListView->setDragEnabled(false);
0514 
0515     QHash<QString, QVector<QImage> > trainSet, testSet;
0516 
0517     QVector<QListWidgetItem*> items = splitData(directory, 0.5, trainSet, testSet);
0518 
0519     for (int i = 0 ; i < items.size() ; ++i)
0520     {
0521         m_imageListView->addItem(items[i]);
0522     }
0523 
0524     connect(m_imageListView, &QListWidget::currentItemChanged,
0525             this, &MainWindow::slotDetectFaces);
0526 
0527     itemsArea->setWidget(m_imageListView);
0528 
0529     return itemsArea;
0530 }
0531 
0532 void MainWindow::slotIdentify(int /*index*/)
0533 {
0534     // TODO : fix this
0535     //m_currentIdenity = m_recognizer->findIdenity(m_preprocessedFaces[index]);
0536 
0537     if (m_currentIdenity.isNull())
0538     {
0539         m_recognizationInfo->setText(QLatin1String("Cannot recognized"));
0540     }
0541     else
0542     {
0543         m_recognizationInfo->setText(QLatin1String("Recognized"));
0544         m_imageLabel->setText(m_currentIdenity.attribute(QLatin1String("fullName")));
0545     }
0546 }
0547 
0548 void MainWindow::slotSaveIdentity()
0549 {
0550     qCDebug(DIGIKAM_TESTS_LOG) << "assign identity" << m_imageLabel->text();
0551     m_currentIdenity.setAttribute(QLatin1String("fullName"), m_imageLabel->text());
0552 
0553     //m_recognizer->saveIdentity(m_currentIdenity, false);
0554 }
0555 
0556 QCommandLineParser* parseOptions(const QCoreApplication& app)
0557 {
0558     QCommandLineParser* const parser = new QCommandLineParser();
0559     parser->addOption(QCommandLineOption(QLatin1String("dataset"), QLatin1String("Data set folder"), QLatin1String("path relative to data folder")));
0560     parser->addHelpOption();
0561     parser->process(app);
0562 
0563     return parser;
0564 }
0565 
0566 int main(int argc, char* argv[])
0567 {
0568     QApplication app(argc, argv);
0569     app.setApplicationName(QString::fromLatin1("digikam"));
0570 
0571     // Options for commandline parser
0572 
0573    QCommandLineParser* const parser = parseOptions(app);
0574 
0575    if (!parser->isSet(QLatin1String("dataset")))
0576    {
0577        qWarning("Data set is not set !!!");
0578 
0579        return 1;
0580    }
0581 
0582    QDir dataset(parser->value(QLatin1String("dataset")));
0583 
0584    MainWindow* const window = new MainWindow(dataset, nullptr);
0585    window->show();
0586 
0587    return app.exec();
0588 }
0589 
0590 #include "recognition_gui.moc"