File indexing completed on 2024-05-12 15:58:13

0001 /*
0002  *  SPDX-FileCopyrightText: 2010 Edward Apap <schumifer@hotmail.com>
0003  *  SPDX-FileCopyrightText: 2011 José Luis Vergara Toloza <pentalis@gmail.com>
0004  *
0005  *  SPDX-License-Identifier: GPL-2.0-or-later
0006  */
0007 
0008 #ifndef KIS_CONVOLUTION_WORKER_FFT_H
0009 #define KIS_CONVOLUTION_WORKER_FFT_H
0010 
0011 #include <iostream>
0012 
0013 #include <KoChannelInfo.h>
0014 
0015 #include "kis_convolution_worker.h"
0016 #include "kis_math_toolbox.h"
0017 
0018 #include <QMutex>
0019 #include <QVector>
0020 #include <QTextStream>
0021 #include <QFile>
0022 #include <QDir>
0023 
0024 #include <fftw3.h>
0025 
0026 template<class _IteratorFactory_> class KisConvolutionWorkerFFT;
0027 class KisConvolutionWorkerFFTLock
0028 {
0029 private:
0030     static QMutex fftwMutex;
0031     template<class _IteratorFactory_> friend class KisConvolutionWorkerFFT;
0032 };
0033 
0034 QMutex KisConvolutionWorkerFFTLock::fftwMutex;
0035 
0036 
0037 template<class _IteratorFactory_>
0038 class KisConvolutionWorkerFFT : public KisConvolutionWorker<_IteratorFactory_>
0039 {
0040 public:
0041     KisConvolutionWorkerFFT(KisPainter *painter, KoUpdater *progress)
0042         : KisConvolutionWorker<_IteratorFactory_>(painter, progress)
0043     {
0044     }
0045 
0046     ~KisConvolutionWorkerFFT()
0047     {
0048     }
0049 
0050     void execute(const KisConvolutionKernelSP kernel,
0051                  const KisPaintDeviceSP src,
0052                  QPoint srcPos,
0053                  QPoint dstPos,
0054                  QSize areaSize,
0055                  const QRect &dataRect) override
0056     {
0057         // Make the area we cover as small as possible
0058         if (this->m_painter->selection())
0059         {
0060             QRect r = this->m_painter->selection()->selectedRect().intersected(QRect(srcPos, areaSize));
0061             dstPos += r.topLeft() - srcPos;
0062             srcPos = r.topLeft();
0063             areaSize = r.size();
0064         }
0065 
0066         if (areaSize.width() == 0 || areaSize.height() == 0)
0067             return;
0068 
0069         addToProgress(0);
0070         if (isInterrupted()) return;
0071 
0072         const quint32 halfKernelWidth = (kernel->width() - 1) / 2;
0073         const quint32 halfKernelHeight = (kernel->height() - 1) / 2;
0074 
0075         m_fftWidth = areaSize.width() + 4 * halfKernelWidth;
0076         m_fftHeight = areaSize.height() + 2 * halfKernelHeight;
0077 
0078         /**
0079          * FIXME: check whether this "optimization" is needed to
0080          * be uncommented. My tests showed about 30% better performance
0081          * when the line is commented out (DK).
0082          */
0083         //optimumDimensions(m_fftWidth, m_fftHeight);
0084 
0085         m_fftLength = m_fftHeight * (m_fftWidth / 2 + 1);
0086         m_extraMem = (m_fftWidth % 2) ? 1 : 2;
0087 
0088         // create and fill kernel
0089         m_kernelFFT = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * m_fftLength);
0090         memset(m_kernelFFT, 0, sizeof(fftw_complex) * m_fftLength);
0091         fftFillKernelMatrix(kernel, m_kernelFFT);
0092 
0093         // find out which channels need convolving
0094         QList<KoChannelInfo*> convChannelList = this->convolvableChannelList(src);
0095 
0096         m_channelFFT.resize(convChannelList.count());
0097         for (auto i = m_channelFFT.begin(); i != m_channelFFT.end(); ++i) {
0098             *i = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * m_fftLength);
0099         }
0100 
0101         const double kernelFactor = kernel->factor() ? kernel->factor() : 1;
0102         const double fftScale = 1.0 / (m_fftHeight * m_fftWidth) / kernelFactor;
0103 
0104         FFTInfo info (fftScale, convChannelList, kernel, this->m_painter->device()->colorSpace());
0105         int cacheRowStride = m_fftWidth + m_extraMem;
0106 
0107         fillCacheFromDevice(src,
0108                             QRect(srcPos.x() - halfKernelWidth,
0109                                   srcPos.y() - halfKernelHeight,
0110                                   m_fftWidth,
0111                                   m_fftHeight),
0112                             cacheRowStride,
0113                             info, dataRect);
0114 
0115         addToProgress(10);
0116         if (isInterrupted()) return;
0117 
0118         // calculate number off fft operations required for progress reporting
0119         const float progressPerFFT = (100 - 30) / (double)(convChannelList.count() * 2 + 1);
0120 
0121         // perform FFT
0122         fftw_plan fftwPlanForward, fftwPlanBackward;
0123 
0124         KisConvolutionWorkerFFTLock::fftwMutex.lock();
0125         fftwPlanForward = fftw_plan_dft_r2c_2d(m_fftHeight, m_fftWidth, (double*)m_kernelFFT, m_kernelFFT, FFTW_ESTIMATE);
0126         fftwPlanBackward = fftw_plan_dft_c2r_2d(m_fftHeight, m_fftWidth, m_kernelFFT, (double*)m_kernelFFT, FFTW_ESTIMATE);
0127         KisConvolutionWorkerFFTLock::fftwMutex.unlock();
0128 
0129         fftw_execute(fftwPlanForward);
0130         addToProgress(progressPerFFT);
0131         if (isInterrupted()) return;
0132 
0133         for (auto k = m_channelFFT.begin(); k != m_channelFFT.end(); ++k)
0134         {
0135             fftw_execute_dft_r2c(fftwPlanForward, (double*)(*k), *k);
0136             addToProgress(progressPerFFT);
0137             if (isInterrupted()) return;
0138 
0139             fftMultiply(*k, m_kernelFFT);
0140 
0141             fftw_execute_dft_c2r(fftwPlanBackward, *k, (double*)*k);
0142             addToProgress(progressPerFFT);
0143             if (isInterrupted()) return;
0144         }
0145 
0146         KisConvolutionWorkerFFTLock::fftwMutex.lock();
0147         fftw_destroy_plan(fftwPlanForward);
0148         fftw_destroy_plan(fftwPlanBackward);
0149         KisConvolutionWorkerFFTLock::fftwMutex.unlock();
0150 
0151 
0152         writeResultToDevice(QRect(dstPos.x(), dstPos.y(), areaSize.width(), areaSize.height()),
0153                             cacheRowStride, halfKernelWidth, halfKernelHeight,
0154                             info, dataRect);
0155 
0156         addToProgress(20);
0157         cleanUp();
0158     }
0159 
0160     struct FFTInfo {
0161         FFTInfo(qreal _fftScale,
0162                 const QList<KoChannelInfo*> &_convChannelList,
0163                 const KisConvolutionKernelSP kernel,
0164                 const KoColorSpace */*colorSpace*/)
0165             : fftScale(_fftScale),
0166               convChannelList(_convChannelList)
0167         {
0168             KisMathToolbox mathToolbox;
0169 
0170             for (int i = 0; i < convChannelList.count(); ++i) {
0171                 minClamp.append(mathToolbox.minChannelValue(convChannelList[i]));
0172                 maxClamp.append(mathToolbox.maxChannelValue(convChannelList[i]));
0173                 absoluteOffset.append((maxClamp[i] - minClamp[i]) * kernel->offset());
0174 
0175                 if (convChannelList[i]->channelType() == KoChannelInfo::ALPHA) {
0176                     alphaCachePos = i;
0177                     alphaRealPos = convChannelList[i]->pos();
0178                 }
0179             }
0180 
0181             toDoubleFuncPtr.resize(convChannelList.count());
0182             fromDoubleFuncPtr.resize(convChannelList.count());
0183             fromDoubleCheckNullFuncPtr.resize(convChannelList.count());
0184 
0185             bool result = mathToolbox.getToDoubleChannelPtr(convChannelList, toDoubleFuncPtr);
0186             result &= mathToolbox.getFromDoubleChannelPtr(convChannelList, fromDoubleFuncPtr);
0187             result &= mathToolbox.getFromDoubleCheckNullChannelPtr(convChannelList, fromDoubleCheckNullFuncPtr);
0188 
0189             KIS_ASSERT(result);
0190         }
0191 
0192         inline int numChannels() const {
0193             return convChannelList.size();
0194         }
0195 
0196 
0197         QVector<qreal> minClamp;
0198         QVector<qreal> maxClamp;
0199         QVector<qreal> absoluteOffset;
0200 
0201         qreal fftScale {0.0};
0202         QList<KoChannelInfo*> convChannelList;
0203 
0204         QVector<PtrToDouble> toDoubleFuncPtr;
0205         QVector<PtrFromDouble> fromDoubleFuncPtr;
0206         QVector<PtrFromDoubleCheckNull> fromDoubleCheckNullFuncPtr;
0207 
0208         int alphaCachePos {-1};
0209         int alphaRealPos {-1};
0210     };
0211 
0212     void fillCacheFromDevice(KisPaintDeviceSP src,
0213                              const QRect &rect,
0214                              const int cacheRowStride,
0215                              const FFTInfo &info,
0216                              const QRect &dataRect) {
0217 
0218         typename _IteratorFactory_::HLineConstIterator hitSrc =
0219             _IteratorFactory_::createHLineConstIterator(src,
0220                                                         rect.x(), rect.y(), rect.width(),
0221                                                         dataRect);
0222 
0223         const int channelCount = info.numChannels();
0224         QVector<double*> channelPtr(channelCount);
0225         const auto channelPtrBegin = channelPtr.begin();
0226         const auto channelPtrEnd = channelPtr.end();
0227 
0228         auto iFFt = m_channelFFT.constBegin();
0229         for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iFFt) {
0230             *i = (double*)*iFFt;
0231         }
0232 
0233         // prepare cache, reused in all loops
0234         QVector<double*> cacheRowStart(channelCount);
0235         const auto cacheRowStartBegin = cacheRowStart.begin();
0236 
0237         for (int y = 0; y < rect.height(); ++y) {
0238             // cache current channelPtr in cacheRowStart
0239             memcpy(cacheRowStart.data(), channelPtr.data(), channelCount * sizeof(double*));
0240 
0241             for (int x = 0; x < rect.width(); ++x) {
0242                 const quint8 *data = hitSrc->oldRawData();
0243 
0244                 // no alpha is a rare case, so just multiply by 1.0 in that case
0245                 double alphaValue = info.alphaRealPos >= 0 ?
0246                     info.toDoubleFuncPtr[info.alphaCachePos](data, info.alphaRealPos) : 1.0;
0247                 int k = 0;
0248                 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) {
0249                     if (k != info.alphaCachePos) {
0250                         const quint32 channelPos = info.convChannelList[k]->pos();
0251                         **i = info.toDoubleFuncPtr[k](data, channelPos) * alphaValue;
0252                     } else {
0253                         **i = alphaValue;
0254                     }
0255 
0256                     ++(*i);
0257                 }
0258 
0259                 hitSrc->nextPixel();
0260             }
0261 
0262             auto iRowStart = cacheRowStartBegin;
0263             for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iRowStart) {
0264                 *i = *iRowStart + cacheRowStride;
0265             }
0266 
0267             hitSrc->nextRow();
0268         }
0269 
0270     }
0271 
0272     inline void limitValue(qreal *value, qreal lowBound, qreal highBound) {
0273         if (*value > highBound) {
0274             *value = highBound;
0275         } else if (!(*value >= lowBound)) {  // value < lowBound or value == NaN
0276             // IEEE compliant comparisons with NaN are always false
0277             *value = lowBound;
0278         }
0279     }
0280 
0281     inline qreal writeAlphaFromCache(quint8* dstPtr,
0282                                      const quint32 channel,
0283                                      const FFTInfo &info,
0284                                      double* channelValuePtr,
0285                                      bool *dstValueIsNull) {
0286         qreal channelPixelValue;
0287 
0288         channelPixelValue = *channelValuePtr * info.fftScale + info.absoluteOffset[channel];
0289         limitValue(&channelPixelValue, info.minClamp[channel], info.maxClamp[channel]);
0290         info.fromDoubleCheckNullFuncPtr[channel](dstPtr, info.convChannelList[channel]->pos(), channelPixelValue, dstValueIsNull);
0291 
0292         return channelPixelValue;
0293     }
0294 
0295     template <bool additionalMultiplierActive>
0296     inline qreal writeOneChannelFromCache(quint8* dstPtr,
0297                                           const quint32 channel,
0298                                           const FFTInfo &info,
0299                                           double* channelValuePtr,
0300                                           const qreal additionalMultiplier = 0.0) {
0301         qreal channelPixelValue;
0302 
0303         if (additionalMultiplierActive) {
0304             channelPixelValue = *channelValuePtr * info.fftScale * additionalMultiplier + info.absoluteOffset[channel];
0305         } else {
0306             channelPixelValue = *channelValuePtr * info.fftScale + info.absoluteOffset[channel];
0307         }
0308 
0309         limitValue(&channelPixelValue, info.minClamp[channel], info.maxClamp[channel]);
0310 
0311         info.fromDoubleFuncPtr[channel](dstPtr, info.convChannelList[channel]->pos(), channelPixelValue);
0312 
0313         return channelPixelValue;
0314     }
0315 
0316     void writeResultToDevice(const QRect &rect,
0317                              const int cacheRowStride,
0318                              const int halfKernelWidth,
0319                              const int halfKernelHeight,
0320                              const FFTInfo &info,
0321                              const QRect &dataRect) {
0322 
0323         typename _IteratorFactory_::HLineIterator hitDst =
0324             _IteratorFactory_::createHLineIterator(this->m_painter->device(),
0325                                                    rect.x(), rect.y(), rect.width(),
0326                                                    dataRect);
0327 
0328         int initialOffset = cacheRowStride * halfKernelHeight + halfKernelWidth;
0329 
0330         const int channelCount = info.numChannels();
0331         QVector<double*> channelPtr(channelCount);
0332         const auto channelPtrBegin = channelPtr.begin();
0333         const auto channelPtrEnd = channelPtr.end();
0334 
0335         auto iFFt = m_channelFFT.constBegin();
0336         for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iFFt) {
0337             *i = (double*)*iFFt + initialOffset;
0338         }
0339 
0340         // prepare cache, reused in all loops
0341         QVector<double*> cacheRowStart(channelCount);
0342         const auto cacheRowStartBegin = cacheRowStart.begin();
0343 
0344         for (int y = 0; y < rect.height(); ++y) {
0345             // cache current channelPtr in cacheRowStart
0346             memcpy(cacheRowStart.data(), channelPtr.data(), channelCount * sizeof(double*));
0347 
0348             for (int x = 0; x < rect.width(); ++x) {
0349                 quint8 *dstPtr = hitDst->rawData();
0350 
0351                 if (info.alphaCachePos >= 0) {
0352                     bool alphaIsNullInDstSpace = false;
0353 
0354                     qreal alphaValue =
0355                         writeAlphaFromCache(dstPtr,
0356                                             info.alphaCachePos,
0357                                             info,
0358                                             channelPtr.at(info.alphaCachePos),
0359                                             &alphaIsNullInDstSpace);
0360 
0361                     if (!alphaIsNullInDstSpace &&
0362                         alphaValue > std::numeric_limits<qreal>::epsilon()) {
0363 
0364                         qreal alphaValueInv = 1.0 / alphaValue;
0365 
0366                         int k = 0;
0367                         for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) {
0368                             if (k != info.alphaCachePos) {
0369                                 writeOneChannelFromCache<true>(dstPtr,
0370                                                                k,
0371                                                                info,
0372                                                                *i,
0373                                                                alphaValueInv);
0374                             }
0375                             ++(*i);
0376                         }
0377                     } else {
0378                         int k = 0;
0379                         for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) {
0380                             if (k != info.alphaCachePos) {
0381                                 info.fromDoubleFuncPtr[k](dstPtr,
0382                                         info.convChannelList[k]->pos(),
0383                                         0.0);
0384                             }
0385                             ++(*i);
0386                         }
0387                     }
0388                 } else {
0389                     int k = 0;
0390                     for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) {
0391                         writeOneChannelFromCache<false>(dstPtr,
0392                                                         k,
0393                                                         info,
0394                                                         *i);
0395                        ++(*i);
0396                     }
0397                 }
0398 
0399                 hitDst->nextPixel();
0400             }
0401 
0402             auto iRowStart = cacheRowStartBegin;
0403             for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iRowStart) {
0404                 *i = *iRowStart + cacheRowStride;
0405             }
0406 
0407             hitDst->nextRow();
0408         }
0409 
0410     }
0411 
0412 private:
0413     void fftFillKernelMatrix(const KisConvolutionKernelSP kernel, fftw_complex *m_kernelFFT)
0414     {
0415         // find central item
0416         QPoint offset((kernel->width() - 1) / 2, (kernel->height() - 1) / 2);
0417 
0418         qint32 xShift = m_fftWidth - offset.x();
0419         qint32 yShift = m_fftHeight - offset.y();
0420 
0421         quint32 absXpos, absYpos;
0422 
0423         for (quint32 y = 0; y < kernel->height(); y++)
0424         {
0425             absYpos = y + yShift;
0426             if (absYpos >= m_fftHeight)
0427                 absYpos -= m_fftHeight;
0428 
0429             for (quint32 x = 0; x < kernel->width(); x++)
0430             {
0431                 absXpos = x + xShift;
0432                 if (absXpos >= m_fftWidth)
0433                     absXpos -= m_fftWidth;
0434 
0435                 ((double*)m_kernelFFT)[(m_fftWidth + m_extraMem) * absYpos + absXpos] = kernel->data()->coeff(y, x);
0436             }
0437         }
0438     }
0439 
0440     void fftMultiply(fftw_complex* channel, fftw_complex* kernel)
0441     {
0442         // perform complex multiplication
0443         fftw_complex *channelPtr = channel;
0444         fftw_complex *kernelPtr = kernel;
0445 
0446         fftw_complex tmp;
0447 
0448         for (quint32 pixelPos = 0; pixelPos < m_fftLength; ++pixelPos)
0449         {
0450             tmp[0] = ((*channelPtr)[0] * (*kernelPtr)[0]) - ((*channelPtr)[1] * (*kernelPtr)[1]);
0451             tmp[1] = ((*channelPtr)[0] * (*kernelPtr)[1]) + ((*channelPtr)[1] * (*kernelPtr)[0]);
0452 
0453             (*channelPtr)[0] = tmp[0];
0454             (*channelPtr)[1] = tmp[1];
0455 
0456             ++channelPtr;
0457             ++kernelPtr;
0458         }
0459     }
0460 
0461     void optimumDimensions(quint32& w, quint32& h)
0462     {
0463         // FFTW is most efficient when array size is a factor of 2, 3, 5 or 7
0464         quint32 optW = w, optH = h;
0465 
0466         while ((optW % 2 != 0) || (optW % 3 != 0) || (optW % 5 != 0) || (optW % 7 != 0))
0467             ++optW;
0468 
0469         while ((optH % 2 != 0) || (optH % 3 != 0) || (optH % 5 != 0) || (optH % 7 != 0))
0470             ++optH;
0471 
0472         quint32 optAreaW = optW * h;
0473         quint32 optAreaH = optH * w;
0474 
0475         if (optAreaW < optAreaH) {
0476             w = optW;
0477         }
0478         else {
0479             h = optH;
0480         }
0481     }
0482 
0483     void fftLogMatrix(double* channel, const QString &f)
0484     {
0485         KisConvolutionWorkerFFTLock::fftwMutex.lock();
0486         QString filename(QDir::homePath() + "/log_" + f + ".txt");
0487         dbgKrita << "Log File Name: " << filename;
0488         QFile file (filename);
0489         if (!file.open(QIODevice::WriteOnly | QIODevice::Truncate | QIODevice::Text))
0490         {
0491             dbgKrita << "Failed";
0492             KisConvolutionWorkerFFTLock::fftwMutex.unlock();
0493             return;
0494         }
0495 
0496         QTextStream in(&file);
0497         in.setCodec("UTF-8");
0498         for (quint32 y = 0; y < m_fftHeight; y++)
0499         {
0500             for (quint32 x = 0; x < m_fftWidth; x++)
0501             {
0502                 QString num = QString::number(channel[y * m_fftWidth + x]);
0503                 while (num.length() < 15)
0504                     num += " ";
0505 
0506                 in << num << " ";
0507             }
0508             in << "\n";
0509         }
0510         KisConvolutionWorkerFFTLock::fftwMutex.unlock();
0511     }
0512 
0513     void addToProgress(float amount)
0514     {
0515         m_currentProgress += amount;
0516 
0517         if (this->m_progress) {
0518             this->m_progress->setProgress((int)m_currentProgress);
0519         }
0520     }
0521 
0522     bool isInterrupted()
0523     {
0524         if (this->m_progress && this->m_progress->interrupted()) {
0525             cleanUp();
0526             return true;
0527         }
0528 
0529         return false;
0530     }
0531 
0532     void cleanUp()
0533     {
0534         // free kernel fft data
0535         if (m_kernelFFT) {
0536             fftw_free(m_kernelFFT);
0537         }
0538 
0539         Q_FOREACH (fftw_complex *channel, m_channelFFT) {
0540             fftw_free(channel);
0541         }
0542         m_channelFFT.clear();
0543     }
0544 private:
0545     quint32 m_fftWidth {0};
0546     quint32 m_fftHeight {0};
0547     quint32 m_fftLength {0};
0548     quint32 m_extraMem {0};
0549     float m_currentProgress {0.0};
0550 
0551     fftw_complex* m_kernelFFT {0};
0552     QVector<fftw_complex*> m_channelFFT;
0553 };
0554 
0555 #endif