File indexing completed on 2024-05-12 15:58:13
0001 /* 0002 * SPDX-FileCopyrightText: 2010 Edward Apap <schumifer@hotmail.com> 0003 * SPDX-FileCopyrightText: 2011 José Luis Vergara Toloza <pentalis@gmail.com> 0004 * 0005 * SPDX-License-Identifier: GPL-2.0-or-later 0006 */ 0007 0008 #ifndef KIS_CONVOLUTION_WORKER_FFT_H 0009 #define KIS_CONVOLUTION_WORKER_FFT_H 0010 0011 #include <iostream> 0012 0013 #include <KoChannelInfo.h> 0014 0015 #include "kis_convolution_worker.h" 0016 #include "kis_math_toolbox.h" 0017 0018 #include <QMutex> 0019 #include <QVector> 0020 #include <QTextStream> 0021 #include <QFile> 0022 #include <QDir> 0023 0024 #include <fftw3.h> 0025 0026 template<class _IteratorFactory_> class KisConvolutionWorkerFFT; 0027 class KisConvolutionWorkerFFTLock 0028 { 0029 private: 0030 static QMutex fftwMutex; 0031 template<class _IteratorFactory_> friend class KisConvolutionWorkerFFT; 0032 }; 0033 0034 QMutex KisConvolutionWorkerFFTLock::fftwMutex; 0035 0036 0037 template<class _IteratorFactory_> 0038 class KisConvolutionWorkerFFT : public KisConvolutionWorker<_IteratorFactory_> 0039 { 0040 public: 0041 KisConvolutionWorkerFFT(KisPainter *painter, KoUpdater *progress) 0042 : KisConvolutionWorker<_IteratorFactory_>(painter, progress) 0043 { 0044 } 0045 0046 ~KisConvolutionWorkerFFT() 0047 { 0048 } 0049 0050 void execute(const KisConvolutionKernelSP kernel, 0051 const KisPaintDeviceSP src, 0052 QPoint srcPos, 0053 QPoint dstPos, 0054 QSize areaSize, 0055 const QRect &dataRect) override 0056 { 0057 // Make the area we cover as small as possible 0058 if (this->m_painter->selection()) 0059 { 0060 QRect r = this->m_painter->selection()->selectedRect().intersected(QRect(srcPos, areaSize)); 0061 dstPos += r.topLeft() - srcPos; 0062 srcPos = r.topLeft(); 0063 areaSize = r.size(); 0064 } 0065 0066 if (areaSize.width() == 0 || areaSize.height() == 0) 0067 return; 0068 0069 addToProgress(0); 0070 if (isInterrupted()) return; 0071 0072 const quint32 halfKernelWidth = (kernel->width() - 1) / 2; 0073 const quint32 halfKernelHeight = (kernel->height() - 1) / 2; 0074 0075 m_fftWidth = areaSize.width() + 4 * halfKernelWidth; 0076 m_fftHeight = areaSize.height() + 2 * halfKernelHeight; 0077 0078 /** 0079 * FIXME: check whether this "optimization" is needed to 0080 * be uncommented. My tests showed about 30% better performance 0081 * when the line is commented out (DK). 0082 */ 0083 //optimumDimensions(m_fftWidth, m_fftHeight); 0084 0085 m_fftLength = m_fftHeight * (m_fftWidth / 2 + 1); 0086 m_extraMem = (m_fftWidth % 2) ? 1 : 2; 0087 0088 // create and fill kernel 0089 m_kernelFFT = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * m_fftLength); 0090 memset(m_kernelFFT, 0, sizeof(fftw_complex) * m_fftLength); 0091 fftFillKernelMatrix(kernel, m_kernelFFT); 0092 0093 // find out which channels need convolving 0094 QList<KoChannelInfo*> convChannelList = this->convolvableChannelList(src); 0095 0096 m_channelFFT.resize(convChannelList.count()); 0097 for (auto i = m_channelFFT.begin(); i != m_channelFFT.end(); ++i) { 0098 *i = (fftw_complex *)fftw_malloc(sizeof(fftw_complex) * m_fftLength); 0099 } 0100 0101 const double kernelFactor = kernel->factor() ? kernel->factor() : 1; 0102 const double fftScale = 1.0 / (m_fftHeight * m_fftWidth) / kernelFactor; 0103 0104 FFTInfo info (fftScale, convChannelList, kernel, this->m_painter->device()->colorSpace()); 0105 int cacheRowStride = m_fftWidth + m_extraMem; 0106 0107 fillCacheFromDevice(src, 0108 QRect(srcPos.x() - halfKernelWidth, 0109 srcPos.y() - halfKernelHeight, 0110 m_fftWidth, 0111 m_fftHeight), 0112 cacheRowStride, 0113 info, dataRect); 0114 0115 addToProgress(10); 0116 if (isInterrupted()) return; 0117 0118 // calculate number off fft operations required for progress reporting 0119 const float progressPerFFT = (100 - 30) / (double)(convChannelList.count() * 2 + 1); 0120 0121 // perform FFT 0122 fftw_plan fftwPlanForward, fftwPlanBackward; 0123 0124 KisConvolutionWorkerFFTLock::fftwMutex.lock(); 0125 fftwPlanForward = fftw_plan_dft_r2c_2d(m_fftHeight, m_fftWidth, (double*)m_kernelFFT, m_kernelFFT, FFTW_ESTIMATE); 0126 fftwPlanBackward = fftw_plan_dft_c2r_2d(m_fftHeight, m_fftWidth, m_kernelFFT, (double*)m_kernelFFT, FFTW_ESTIMATE); 0127 KisConvolutionWorkerFFTLock::fftwMutex.unlock(); 0128 0129 fftw_execute(fftwPlanForward); 0130 addToProgress(progressPerFFT); 0131 if (isInterrupted()) return; 0132 0133 for (auto k = m_channelFFT.begin(); k != m_channelFFT.end(); ++k) 0134 { 0135 fftw_execute_dft_r2c(fftwPlanForward, (double*)(*k), *k); 0136 addToProgress(progressPerFFT); 0137 if (isInterrupted()) return; 0138 0139 fftMultiply(*k, m_kernelFFT); 0140 0141 fftw_execute_dft_c2r(fftwPlanBackward, *k, (double*)*k); 0142 addToProgress(progressPerFFT); 0143 if (isInterrupted()) return; 0144 } 0145 0146 KisConvolutionWorkerFFTLock::fftwMutex.lock(); 0147 fftw_destroy_plan(fftwPlanForward); 0148 fftw_destroy_plan(fftwPlanBackward); 0149 KisConvolutionWorkerFFTLock::fftwMutex.unlock(); 0150 0151 0152 writeResultToDevice(QRect(dstPos.x(), dstPos.y(), areaSize.width(), areaSize.height()), 0153 cacheRowStride, halfKernelWidth, halfKernelHeight, 0154 info, dataRect); 0155 0156 addToProgress(20); 0157 cleanUp(); 0158 } 0159 0160 struct FFTInfo { 0161 FFTInfo(qreal _fftScale, 0162 const QList<KoChannelInfo*> &_convChannelList, 0163 const KisConvolutionKernelSP kernel, 0164 const KoColorSpace */*colorSpace*/) 0165 : fftScale(_fftScale), 0166 convChannelList(_convChannelList) 0167 { 0168 KisMathToolbox mathToolbox; 0169 0170 for (int i = 0; i < convChannelList.count(); ++i) { 0171 minClamp.append(mathToolbox.minChannelValue(convChannelList[i])); 0172 maxClamp.append(mathToolbox.maxChannelValue(convChannelList[i])); 0173 absoluteOffset.append((maxClamp[i] - minClamp[i]) * kernel->offset()); 0174 0175 if (convChannelList[i]->channelType() == KoChannelInfo::ALPHA) { 0176 alphaCachePos = i; 0177 alphaRealPos = convChannelList[i]->pos(); 0178 } 0179 } 0180 0181 toDoubleFuncPtr.resize(convChannelList.count()); 0182 fromDoubleFuncPtr.resize(convChannelList.count()); 0183 fromDoubleCheckNullFuncPtr.resize(convChannelList.count()); 0184 0185 bool result = mathToolbox.getToDoubleChannelPtr(convChannelList, toDoubleFuncPtr); 0186 result &= mathToolbox.getFromDoubleChannelPtr(convChannelList, fromDoubleFuncPtr); 0187 result &= mathToolbox.getFromDoubleCheckNullChannelPtr(convChannelList, fromDoubleCheckNullFuncPtr); 0188 0189 KIS_ASSERT(result); 0190 } 0191 0192 inline int numChannels() const { 0193 return convChannelList.size(); 0194 } 0195 0196 0197 QVector<qreal> minClamp; 0198 QVector<qreal> maxClamp; 0199 QVector<qreal> absoluteOffset; 0200 0201 qreal fftScale {0.0}; 0202 QList<KoChannelInfo*> convChannelList; 0203 0204 QVector<PtrToDouble> toDoubleFuncPtr; 0205 QVector<PtrFromDouble> fromDoubleFuncPtr; 0206 QVector<PtrFromDoubleCheckNull> fromDoubleCheckNullFuncPtr; 0207 0208 int alphaCachePos {-1}; 0209 int alphaRealPos {-1}; 0210 }; 0211 0212 void fillCacheFromDevice(KisPaintDeviceSP src, 0213 const QRect &rect, 0214 const int cacheRowStride, 0215 const FFTInfo &info, 0216 const QRect &dataRect) { 0217 0218 typename _IteratorFactory_::HLineConstIterator hitSrc = 0219 _IteratorFactory_::createHLineConstIterator(src, 0220 rect.x(), rect.y(), rect.width(), 0221 dataRect); 0222 0223 const int channelCount = info.numChannels(); 0224 QVector<double*> channelPtr(channelCount); 0225 const auto channelPtrBegin = channelPtr.begin(); 0226 const auto channelPtrEnd = channelPtr.end(); 0227 0228 auto iFFt = m_channelFFT.constBegin(); 0229 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iFFt) { 0230 *i = (double*)*iFFt; 0231 } 0232 0233 // prepare cache, reused in all loops 0234 QVector<double*> cacheRowStart(channelCount); 0235 const auto cacheRowStartBegin = cacheRowStart.begin(); 0236 0237 for (int y = 0; y < rect.height(); ++y) { 0238 // cache current channelPtr in cacheRowStart 0239 memcpy(cacheRowStart.data(), channelPtr.data(), channelCount * sizeof(double*)); 0240 0241 for (int x = 0; x < rect.width(); ++x) { 0242 const quint8 *data = hitSrc->oldRawData(); 0243 0244 // no alpha is a rare case, so just multiply by 1.0 in that case 0245 double alphaValue = info.alphaRealPos >= 0 ? 0246 info.toDoubleFuncPtr[info.alphaCachePos](data, info.alphaRealPos) : 1.0; 0247 int k = 0; 0248 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) { 0249 if (k != info.alphaCachePos) { 0250 const quint32 channelPos = info.convChannelList[k]->pos(); 0251 **i = info.toDoubleFuncPtr[k](data, channelPos) * alphaValue; 0252 } else { 0253 **i = alphaValue; 0254 } 0255 0256 ++(*i); 0257 } 0258 0259 hitSrc->nextPixel(); 0260 } 0261 0262 auto iRowStart = cacheRowStartBegin; 0263 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iRowStart) { 0264 *i = *iRowStart + cacheRowStride; 0265 } 0266 0267 hitSrc->nextRow(); 0268 } 0269 0270 } 0271 0272 inline void limitValue(qreal *value, qreal lowBound, qreal highBound) { 0273 if (*value > highBound) { 0274 *value = highBound; 0275 } else if (!(*value >= lowBound)) { // value < lowBound or value == NaN 0276 // IEEE compliant comparisons with NaN are always false 0277 *value = lowBound; 0278 } 0279 } 0280 0281 inline qreal writeAlphaFromCache(quint8* dstPtr, 0282 const quint32 channel, 0283 const FFTInfo &info, 0284 double* channelValuePtr, 0285 bool *dstValueIsNull) { 0286 qreal channelPixelValue; 0287 0288 channelPixelValue = *channelValuePtr * info.fftScale + info.absoluteOffset[channel]; 0289 limitValue(&channelPixelValue, info.minClamp[channel], info.maxClamp[channel]); 0290 info.fromDoubleCheckNullFuncPtr[channel](dstPtr, info.convChannelList[channel]->pos(), channelPixelValue, dstValueIsNull); 0291 0292 return channelPixelValue; 0293 } 0294 0295 template <bool additionalMultiplierActive> 0296 inline qreal writeOneChannelFromCache(quint8* dstPtr, 0297 const quint32 channel, 0298 const FFTInfo &info, 0299 double* channelValuePtr, 0300 const qreal additionalMultiplier = 0.0) { 0301 qreal channelPixelValue; 0302 0303 if (additionalMultiplierActive) { 0304 channelPixelValue = *channelValuePtr * info.fftScale * additionalMultiplier + info.absoluteOffset[channel]; 0305 } else { 0306 channelPixelValue = *channelValuePtr * info.fftScale + info.absoluteOffset[channel]; 0307 } 0308 0309 limitValue(&channelPixelValue, info.minClamp[channel], info.maxClamp[channel]); 0310 0311 info.fromDoubleFuncPtr[channel](dstPtr, info.convChannelList[channel]->pos(), channelPixelValue); 0312 0313 return channelPixelValue; 0314 } 0315 0316 void writeResultToDevice(const QRect &rect, 0317 const int cacheRowStride, 0318 const int halfKernelWidth, 0319 const int halfKernelHeight, 0320 const FFTInfo &info, 0321 const QRect &dataRect) { 0322 0323 typename _IteratorFactory_::HLineIterator hitDst = 0324 _IteratorFactory_::createHLineIterator(this->m_painter->device(), 0325 rect.x(), rect.y(), rect.width(), 0326 dataRect); 0327 0328 int initialOffset = cacheRowStride * halfKernelHeight + halfKernelWidth; 0329 0330 const int channelCount = info.numChannels(); 0331 QVector<double*> channelPtr(channelCount); 0332 const auto channelPtrBegin = channelPtr.begin(); 0333 const auto channelPtrEnd = channelPtr.end(); 0334 0335 auto iFFt = m_channelFFT.constBegin(); 0336 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iFFt) { 0337 *i = (double*)*iFFt + initialOffset; 0338 } 0339 0340 // prepare cache, reused in all loops 0341 QVector<double*> cacheRowStart(channelCount); 0342 const auto cacheRowStartBegin = cacheRowStart.begin(); 0343 0344 for (int y = 0; y < rect.height(); ++y) { 0345 // cache current channelPtr in cacheRowStart 0346 memcpy(cacheRowStart.data(), channelPtr.data(), channelCount * sizeof(double*)); 0347 0348 for (int x = 0; x < rect.width(); ++x) { 0349 quint8 *dstPtr = hitDst->rawData(); 0350 0351 if (info.alphaCachePos >= 0) { 0352 bool alphaIsNullInDstSpace = false; 0353 0354 qreal alphaValue = 0355 writeAlphaFromCache(dstPtr, 0356 info.alphaCachePos, 0357 info, 0358 channelPtr.at(info.alphaCachePos), 0359 &alphaIsNullInDstSpace); 0360 0361 if (!alphaIsNullInDstSpace && 0362 alphaValue > std::numeric_limits<qreal>::epsilon()) { 0363 0364 qreal alphaValueInv = 1.0 / alphaValue; 0365 0366 int k = 0; 0367 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) { 0368 if (k != info.alphaCachePos) { 0369 writeOneChannelFromCache<true>(dstPtr, 0370 k, 0371 info, 0372 *i, 0373 alphaValueInv); 0374 } 0375 ++(*i); 0376 } 0377 } else { 0378 int k = 0; 0379 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) { 0380 if (k != info.alphaCachePos) { 0381 info.fromDoubleFuncPtr[k](dstPtr, 0382 info.convChannelList[k]->pos(), 0383 0.0); 0384 } 0385 ++(*i); 0386 } 0387 } 0388 } else { 0389 int k = 0; 0390 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++k) { 0391 writeOneChannelFromCache<false>(dstPtr, 0392 k, 0393 info, 0394 *i); 0395 ++(*i); 0396 } 0397 } 0398 0399 hitDst->nextPixel(); 0400 } 0401 0402 auto iRowStart = cacheRowStartBegin; 0403 for (auto i = channelPtrBegin; i != channelPtrEnd; ++i, ++iRowStart) { 0404 *i = *iRowStart + cacheRowStride; 0405 } 0406 0407 hitDst->nextRow(); 0408 } 0409 0410 } 0411 0412 private: 0413 void fftFillKernelMatrix(const KisConvolutionKernelSP kernel, fftw_complex *m_kernelFFT) 0414 { 0415 // find central item 0416 QPoint offset((kernel->width() - 1) / 2, (kernel->height() - 1) / 2); 0417 0418 qint32 xShift = m_fftWidth - offset.x(); 0419 qint32 yShift = m_fftHeight - offset.y(); 0420 0421 quint32 absXpos, absYpos; 0422 0423 for (quint32 y = 0; y < kernel->height(); y++) 0424 { 0425 absYpos = y + yShift; 0426 if (absYpos >= m_fftHeight) 0427 absYpos -= m_fftHeight; 0428 0429 for (quint32 x = 0; x < kernel->width(); x++) 0430 { 0431 absXpos = x + xShift; 0432 if (absXpos >= m_fftWidth) 0433 absXpos -= m_fftWidth; 0434 0435 ((double*)m_kernelFFT)[(m_fftWidth + m_extraMem) * absYpos + absXpos] = kernel->data()->coeff(y, x); 0436 } 0437 } 0438 } 0439 0440 void fftMultiply(fftw_complex* channel, fftw_complex* kernel) 0441 { 0442 // perform complex multiplication 0443 fftw_complex *channelPtr = channel; 0444 fftw_complex *kernelPtr = kernel; 0445 0446 fftw_complex tmp; 0447 0448 for (quint32 pixelPos = 0; pixelPos < m_fftLength; ++pixelPos) 0449 { 0450 tmp[0] = ((*channelPtr)[0] * (*kernelPtr)[0]) - ((*channelPtr)[1] * (*kernelPtr)[1]); 0451 tmp[1] = ((*channelPtr)[0] * (*kernelPtr)[1]) + ((*channelPtr)[1] * (*kernelPtr)[0]); 0452 0453 (*channelPtr)[0] = tmp[0]; 0454 (*channelPtr)[1] = tmp[1]; 0455 0456 ++channelPtr; 0457 ++kernelPtr; 0458 } 0459 } 0460 0461 void optimumDimensions(quint32& w, quint32& h) 0462 { 0463 // FFTW is most efficient when array size is a factor of 2, 3, 5 or 7 0464 quint32 optW = w, optH = h; 0465 0466 while ((optW % 2 != 0) || (optW % 3 != 0) || (optW % 5 != 0) || (optW % 7 != 0)) 0467 ++optW; 0468 0469 while ((optH % 2 != 0) || (optH % 3 != 0) || (optH % 5 != 0) || (optH % 7 != 0)) 0470 ++optH; 0471 0472 quint32 optAreaW = optW * h; 0473 quint32 optAreaH = optH * w; 0474 0475 if (optAreaW < optAreaH) { 0476 w = optW; 0477 } 0478 else { 0479 h = optH; 0480 } 0481 } 0482 0483 void fftLogMatrix(double* channel, const QString &f) 0484 { 0485 KisConvolutionWorkerFFTLock::fftwMutex.lock(); 0486 QString filename(QDir::homePath() + "/log_" + f + ".txt"); 0487 dbgKrita << "Log File Name: " << filename; 0488 QFile file (filename); 0489 if (!file.open(QIODevice::WriteOnly | QIODevice::Truncate | QIODevice::Text)) 0490 { 0491 dbgKrita << "Failed"; 0492 KisConvolutionWorkerFFTLock::fftwMutex.unlock(); 0493 return; 0494 } 0495 0496 QTextStream in(&file); 0497 in.setCodec("UTF-8"); 0498 for (quint32 y = 0; y < m_fftHeight; y++) 0499 { 0500 for (quint32 x = 0; x < m_fftWidth; x++) 0501 { 0502 QString num = QString::number(channel[y * m_fftWidth + x]); 0503 while (num.length() < 15) 0504 num += " "; 0505 0506 in << num << " "; 0507 } 0508 in << "\n"; 0509 } 0510 KisConvolutionWorkerFFTLock::fftwMutex.unlock(); 0511 } 0512 0513 void addToProgress(float amount) 0514 { 0515 m_currentProgress += amount; 0516 0517 if (this->m_progress) { 0518 this->m_progress->setProgress((int)m_currentProgress); 0519 } 0520 } 0521 0522 bool isInterrupted() 0523 { 0524 if (this->m_progress && this->m_progress->interrupted()) { 0525 cleanUp(); 0526 return true; 0527 } 0528 0529 return false; 0530 } 0531 0532 void cleanUp() 0533 { 0534 // free kernel fft data 0535 if (m_kernelFFT) { 0536 fftw_free(m_kernelFFT); 0537 } 0538 0539 Q_FOREACH (fftw_complex *channel, m_channelFFT) { 0540 fftw_free(channel); 0541 } 0542 m_channelFFT.clear(); 0543 } 0544 private: 0545 quint32 m_fftWidth {0}; 0546 quint32 m_fftHeight {0}; 0547 quint32 m_fftLength {0}; 0548 quint32 m_extraMem {0}; 0549 float m_currentProgress {0.0}; 0550 0551 fftw_complex* m_kernelFFT {0}; 0552 QVector<fftw_complex*> m_channelFFT; 0553 }; 0554 0555 #endif