File indexing completed on 2024-05-05 15:55:08

0001 #pragma once
0002 
0003 #include "../../auxiliary/robuststatistics.h"
0004 
0005 #include <QVector>
0006 #include <qcustomplot.h>
0007 #include <gsl/gsl_vector.h>
0008 #include <gsl/gsl_min.h>
0009 #include <gsl/gsl_matrix.h>
0010 #include <gsl/gsl_multifit.h>
0011 #include <gsl/gsl_multifit_nlinear.h>
0012 #include <gsl/gsl_blas.h>
0013 #include <ekos_focus_debug.h>
0014 
0015 namespace Ekos
0016 {
0017 // The curve fitting class provides curve fitting algorithms using the Lehvensberg-Marquart (LM)
0018 // solver as provided the Gnu Science Library (GSL). See the comments at the start of curvefit.cpp
0019 // for more details.
0020 //
0021 // Several curves are provided:
0022 // For lines: hyperbola, parabola
0023 // For surfaces: gaussian, plane
0024 // For compatibility with existing Ekos functionality a Quadratic option using the exising Ekos linear
0025 // least squares solver (again provided by GSL) is supported. The Quadratic and Parabola curves are
0026 // the same thing mathematically but Parabola uses the non-linear least squares LM solver whilst Quadratic
0027 // uses the original Ekos linear least squares solver.
0028 //
0029 // Users of CurveFitting operate on focuser position and HFR. Within CurveFitting the curve uses the more
0030 // usual mathematical notation of x, y
0031 //
0032 // Furture releases may migrate all curve fitting to the LM solver.
0033 class CurveFitting
0034 {
0035     public:
0036         typedef enum { FOCUS_QUADRATIC, FOCUS_HYPERBOLA, FOCUS_PARABOLA, FOCUS_GAUSSIAN, FOCUS_PLANE } CurveFit;
0037         typedef enum { OPTIMISATION_MINIMISE, OPTIMISATION_MAXIMISE } OptimisationDirection;
0038         typedef enum { STANDARD, BEST, BEST_RETRY } FittingGoal;
0039 
0040         // Data structures to hold datapoints for the class
0041         struct DataPT
0042         {
0043             double x;            // focuser position
0044             double y;            // focus measurement, e.g. HFR
0045             double weight;       // the measurement weight, e.g. inverse variance
0046         };
0047 
0048         struct DataPointT        // This is the data strcture passed into GSL LM routines
0049         {
0050             bool useWeights;     // Are we fitting the curve with or without weights
0051             QVector<DataPT> dps; // Vector of DataPT
0052             OptimisationDirection dir = OPTIMISATION_MINIMISE; // Are we minimising or maximising?
0053             Mathematics::RobustStatistics::ScaleCalculation weightCalculation =
0054                 Mathematics::RobustStatistics::SCALE_VARIANCE; // How to compute weights
0055 
0056             // Helper functions to operate on the data structure
0057             void push_back(double x, double y, double weight = 1)
0058             {
0059                 dps.push_back({x, y, weight});
0060             }
0061         };
0062 
0063         struct DataPT3D
0064         {
0065             double x;            // x position
0066             double y;            // y position
0067             double z;               // value
0068             double weight;       // the measurement weight, e.g. inverse variance
0069         };
0070 
0071         struct DataPoint3DT      // This is the data strcture passed into GSL LM routines
0072         {
0073             bool useWeights;     // Are we fitting the curve with or without weights
0074             QVector<DataPT3D> dps; // Vector of DataPT3D
0075             OptimisationDirection dir = OPTIMISATION_MAXIMISE; // Are we minimising or maximising?
0076             Mathematics::RobustStatistics::ScaleCalculation weightCalculation =
0077                 Mathematics::RobustStatistics::SCALE_VARIANCE; // How to compute weights
0078 
0079             // Helper function to operate on the data structure
0080             void push_back(double x, double y, double z, double weight = 1)
0081             {
0082                 dps.push_back({x, y, z, weight});
0083             }
0084         };
0085 
0086         // Data structure to hold star parameters
0087         struct StarParams
0088         {
0089             double background;
0090             double peak;
0091             double centroid_x;
0092             double centroid_y;
0093             double HFR;
0094             double theta;
0095             double FWHMx;
0096             double FWHMy;
0097             double FWHM;
0098         };
0099 
0100         // Constructor just initialises the object
0101         CurveFitting();
0102 
0103         // A constructor that reconstructs a previously built object.
0104         // Does not implement getting the original data points.
0105         CurveFitting(const QString &serialized);
0106 
0107         // fitCurve takes in the vectors with the position, hfr and weight (e.g. variance in HFR) values
0108         // along with the type of curve to use and whether or not to use weights in the calculation
0109         // It fits the curve and solves for the coefficients.
0110         void fitCurve(const FittingGoal goal, const QVector<int> &position, const QVector<double> &hfr,
0111                       const QVector<double> &weights, const QVector<bool> &outliers,
0112                       const CurveFit curveFit, const bool useWeights, const OptimisationDirection optDir);
0113 
0114         // fitCurve3D 3-Dimensional version of fitCurve
0115         // Data is passed in in imageBuffer - a 2D array of width x height
0116         // Approx star information is passed in to seed the LM solver initial parameters.
0117         // Start and end define the x,y coordinates of a box around the star, start is top left corner, end is bottom right
0118         template <typename T>
0119         void fitCurve3D(const T *imageBuffer, const int imageWidth, const QPair<int, int> start, const QPair<int, int> end,
0120                         const StarParams &starParams, const CurveFit curveFit, const bool useWeights)
0121         {
0122             if (imageBuffer == nullptr)
0123             {
0124                 qCDebug(KSTARS_EKOS_FOCUS) << QString("CurveFitting::fitCurve3D null image ptr");
0125                 m_FirstSolverRun = true;
0126                 return;
0127             }
0128 
0129             if (imageWidth <= 0)
0130             {
0131                 qCDebug(KSTARS_EKOS_FOCUS) << QString("CurveFitting::fitCurve3D imageWidth=%1").arg(imageWidth);
0132                 m_FirstSolverRun = true;
0133                 return;
0134             }
0135 
0136             if (useWeights)
0137             {
0138                 qCDebug(KSTARS_EKOS_FOCUS) << QString("CurveFitting::fitCurve3D called with useWeights. Not yet implemented");
0139                 m_FirstSolverRun = true;
0140                 return;
0141             }
0142 
0143             m_dataPoints.dps.clear();
0144             m_dataPoints.useWeights = useWeights;
0145 
0146             // Load up the data structures for the solver.
0147             // The pixel reference x, y refers to the top left corner of the pizel so add 0.5 to x and y to reference the
0148             // centre of the pixel.
0149             int width = end.first - start.first;
0150             int height = end.second - start.second;
0151 
0152             for (int j = 0; j < height; j++)
0153                 for (int i = 0; i < width; i++)
0154                     m_dataPoints.push_back(i + 0.5, j + 0.5, imageBuffer[start.first + i + ((start.second + j) * imageWidth)], 1.0);
0155 
0156             m_CurveType = curveFit;
0157             switch (m_CurveType)
0158             {
0159                 case FOCUS_GAUSSIAN :
0160                     m_coefficients = gaussian_fit(m_dataPoints, starParams);
0161                     break;
0162                 default :
0163                     // Something went wrong, log an error and reset state so solver starts from scratch if called again
0164                     qCDebug(KSTARS_EKOS_FOCUS) << QString("CurveFitting::fitCurve3D called with curveFit=%1").arg(curveFit);
0165                     m_FirstSolverRun = true;
0166                     return;
0167             }
0168             m_LastCoefficients = m_coefficients;
0169             m_LastCurveType    = m_CurveType;
0170             m_FirstSolverRun   = false;
0171         }
0172 
0173         // Alternative form of fitCurve3D used on non-image data
0174         void fitCurve3D(const DataPoint3DT data, const CurveFit curveFit);
0175 
0176         // Returns the minimum position and value in the pointers for the solved curve.
0177         // Returns false if the curve couldn't be solved
0178         bool findMinMax(double expected, double minPosition, double maxPosition, double *position, double *value, CurveFit curveFit,
0179                         const OptimisationDirection optDir);
0180         // getStarParams returns the star parameters for the solved star
0181         bool getStarParams(const CurveFit curveFit, StarParams *starParams);
0182 
0183         // getCurveParams gets the coefficients of a curve solve
0184         // setCurveParams sets the coefficients of a curve solve
0185         // using get and set returns the solver to its state as it was when get was called
0186         // This allows functions like "f" to be called to calculate curve values for a
0187         // prior curve fit solution.
0188         bool getCurveParams(const CurveFit curveType, QVector<double> &coefficients)
0189         {
0190             if (curveType != m_CurveType)
0191                 return false;
0192             coefficients = m_coefficients;
0193             return true;
0194         }
0195 
0196         bool setCurveParams(const CurveFit curveType, const QVector<double> coefficients)
0197         {
0198             if (curveType != m_CurveType)
0199                 return false;
0200             m_coefficients = coefficients;
0201             return true;
0202         }
0203 
0204         // f calculates the value of y for a given x using the appropriate curve algorithm
0205         double f(double x);
0206         // f calculates the value of z for a given x and y using the appropriate curve algorithm
0207         double f3D(double x, double y);
0208         // Calculates the R-squared which is a measure of how well the curve fits the datapoints
0209         double calculateR2(CurveFit curveFit);
0210         // Calculate the deltas of each datapoint from the curve
0211         void calculateCurveDeltas(CurveFit curveFit, std::vector<std::pair<int, double>> &curveDeltas);
0212 
0213         // Returns a QString which can be used to construct an identical object.
0214         // Does not implement getting the original data points.
0215         QString serialize() const;
0216 
0217     private:
0218         // TODO: This function will likely go when Linear and L1P merge to be closer.
0219         // Calculates the value of the polynomial at x. Params will be cast to a CurveFit*.
0220         static double curveFunction(double x, void *params);
0221 
0222         // TODO: This function will likely go when Linear and L1P merge to be closer.
0223         QVector<double> polynomial_fit(const double *const data_x, const double *const data_y, const int n, const int order);
0224 
0225         QVector<double> hyperbola_fit(FittingGoal goal, const QVector<double> data_x, const QVector<double> data_y,
0226                                       const QVector<double> weights, bool useWeights, const OptimisationDirection optDir);
0227         QVector<double> parabola_fit(FittingGoal goal, const QVector<double> data_x, const QVector<double> data_y,
0228                                      const QVector<double> data_weights,
0229                                      bool useWeights, const OptimisationDirection optDir);
0230         QVector<double> gaussian_fit(DataPoint3DT data, const StarParams &starParams);
0231         QVector<double> plane_fit(const DataPoint3DT data);
0232 
0233         bool minimumQuadratic(double expected, double minPosition, double maxPosition, double *position, double *value);
0234         bool minMaxHyperbola(double expected, double minPosition, double maxPosition, double *position, double *value,
0235                              const OptimisationDirection optDir);
0236         bool minMaxParabola(double expected, double minPosition, double maxPosition, double *position, double *value,
0237                             const OptimisationDirection optDir);
0238         bool getGaussianParams(StarParams *starParams);
0239 
0240         void hypMakeGuess(const int attempt, const QVector<double> inX, const QVector<double> inY,
0241                           const OptimisationDirection optDir, gsl_vector * guess);
0242         void hypSetupParams(FittingGoal goal, gsl_multifit_nlinear_parameters *params, int *numIters, double *xtol, double *gtol,
0243                             double *ftol);
0244 
0245         void parMakeGuess(const int attempt, const QVector<double> inX, const QVector<double> inY,
0246                           const OptimisationDirection optDir,
0247                           gsl_vector * guess);
0248         void parSetupParams(FittingGoal goal, gsl_multifit_nlinear_parameters *params, int *numIters, double *xtol, double *gtol,
0249                             double *ftol);
0250         void gauMakeGuess(const int attempt, const StarParams &starParams, gsl_vector * guess);
0251         void gauSetupParams(gsl_multifit_nlinear_parameters *params, int *numIters, double *xtol, double *gtol, double *ftol);
0252         void plaMakeGuess(const int attempt, gsl_vector * guess);
0253         void plaSetupParams(gsl_multifit_nlinear_parameters *params, int *numIters, double *xtol, double *gtol, double *ftol);
0254 
0255         // Get the reason code from the passed in info
0256         QString getLMReasonCode(int info);
0257 
0258         // Calculation engine for the R-squared which is a measure of how well the curve fits the datapoints
0259         double calcR2(const QVector<double> dataPoints, const QVector<double> curvePoints, const QVector<double> scale,
0260                       const bool useWeights);
0261 
0262         // Used in the QString constructor.
0263         bool recreateFromQString(const QString &serialized);
0264 
0265         // Type of curve
0266         CurveFit m_CurveType;
0267         // The data values.
0268         QVector<double> m_x, m_y, m_scale;
0269         // Use weights or not
0270         bool m_useWeights;
0271         DataPoint3DT m_dataPoints;
0272         // The solved parameters.
0273         QVector<double> m_coefficients;
0274         // State variables used by the LM solver. These variables provide a way of optimising the starting
0275         // point for the solver by using the solution found by the previous run providing the relevant
0276         // solver parameters are consistent between runs.
0277         bool m_FirstSolverRun;
0278         CurveFit m_LastCurveType;
0279         QVector<double> m_LastCoefficients;
0280 };
0281 
0282 } //namespace