File indexing completed on 2025-02-23 05:15:17

0001 //
0002 // Copyright (C) 2004-2006 Maciej Sobczak, Stephen Hutton
0003 // MySQL backend copyright (C) 2006 Pawel Aleksander Fedorynski
0004 // Distributed under the Boost Software License, Version 1.0.
0005 // (See accompanying file LICENSE_1_0.txt or copy at
0006 // http://www.boost.org/LICENSE_1_0.txt)
0007 //
0008 
0009 #define SOCI_MYSQL_SOURCE
0010 #include "soci/mysql/soci-mysql.h"
0011 #include "soci/connection-parameters.h"
0012 // std
0013 #include <cctype>
0014 #include <cerrno>
0015 #include <ciso646>
0016 #include <climits>
0017 #include <cstdio>
0018 #include <cstdlib>
0019 #include <string>
0020 
0021 #ifdef _MSC_VER
0022 #pragma warning(disable:4355)
0023 #endif
0024 
0025 using namespace soci;
0026 using namespace soci::details;
0027 using std::string;
0028 
0029 
0030 namespace
0031 { // anonymous
0032 
0033 // Helper class used to ensure we call mysql_library_init() before opening the
0034 // first MySQL connection and mysql_library_end() on application shutdown.
0035 class mysql_library
0036 {
0037 private:
0038     mysql_library()
0039     {
0040         if (mysql_library_init(0, NULL, NULL))
0041         {
0042             throw soci_error("Failed to initialize MySQL library.");
0043         }
0044     }
0045 
0046     ~mysql_library()
0047     {
0048         mysql_library_end();
0049     }
0050 
0051 public:
0052     // This function only exists for its side effect of creating an instance of
0053     // this class.
0054     static void ensure_initialized()
0055     {
0056         static mysql_library ins;
0057     }
0058 };
0059 
0060 void skip_white(std::string::const_iterator *i,
0061     std::string::const_iterator const & end, bool endok)
0062 {
0063     for (;;)
0064     {
0065         if (*i == end)
0066         {
0067             if (endok)
0068             {
0069                 return;
0070             }
0071             else
0072             {
0073                 throw soci_error("Unexpected end of connection string.");
0074             }
0075         }
0076         if (std::isspace(**i))
0077         {
0078             ++*i;
0079         }
0080         else
0081         {
0082             return;
0083         }
0084     }
0085 }
0086 
// Reads a parameter name (a run of letters and underscores) starting at *i.
// *i is left on the first character that is not part of the name (or end).
// Returns an empty string if no name character is present.
std::string param_name(std::string::const_iterator *i,
    std::string::const_iterator const & end)
{
    std::string val;
    while (*i != end)
    {
        // Convert to unsigned char first: isalpha() on a negative char value
        // (e.g. a UTF-8 byte where char is signed) is undefined behavior.
        unsigned char const c = static_cast<unsigned char>(**i);
        if (!std::isalpha(c) && c != '_')
        {
            break;
        }
        val += **i;
        ++*i;
    }
    return val;
}
0102 
0103 string param_value(string::const_iterator *i,
0104     string::const_iterator const & end)
0105 {
0106     string err = "Malformed connection string.";
0107     bool quot;
0108     if (**i == '\'')
0109     {
0110         quot = true;
0111         ++*i;
0112     }
0113     else
0114     {
0115         quot = false;
0116     }
0117     string val("");
0118     for (;;)
0119     {
0120         if (*i == end)
0121         {
0122             if (quot)
0123             {
0124                 throw soci_error(err);
0125             }
0126             else
0127             {
0128                 break;
0129             }
0130         }
0131         if (**i == '\'')
0132         {
0133             if (quot)
0134             {
0135                 ++*i;
0136                 break;
0137             }
0138             else
0139             {
0140                 throw soci_error(err);
0141             }
0142         }
0143         if (!quot && std::isspace(**i))
0144         {
0145             break;
0146         }
0147         if (**i == '\\')
0148         {
0149             ++*i;
0150             if (*i == end)
0151             {
0152                 throw soci_error(err);
0153             }
0154         }
0155         val += **i;
0156         ++*i;
0157     }
0158     return val;
0159 }
0160 
// Returns true if s is a base-10 integer (as accepted by strtol(), so leading
// whitespace and a sign are allowed) that fits in int, with no trailing
// characters. An empty or digit-less string is rejected.
bool valid_int(const std::string & s)
{
    char *tail;
    const char *cstr = s.c_str();
    errno = 0;
    long const n = std::strtol(cstr, &tail, 10);

    // When no digits are consumed strtol() returns 0 and sets tail to cstr;
    // without this check "" would be mistaken for a valid "0".
    if (tail == cstr)
    {
        return false;
    }
    if (errno != 0 || n > INT_MAX || n < INT_MIN)
    {
        return false;
    }
    return *tail == '\0';
}
0177 
// Returns true if s is a strictly positive base-10 integer that fits in
// unsigned int, with no trailing characters. Leading whitespace and a '+'
// sign are tolerated (as by strtoul()), but a minus sign is rejected.
bool valid_uint(const std::string & s)
{
    // strtoul() accepts a leading '-' and negates the result in unsigned
    // arithmetic, so e.g. "-1" could otherwise pass as a huge positive value.
    std::string::size_type const first = s.find_first_not_of(" \t\n\v\f\r");
    if (first != std::string::npos && s[first] == '-')
        return false;

    char *tail;
    const char *cstr = s.c_str();
    errno = 0;
    unsigned long const n = std::strtoul(cstr, &tail, 10);
    // n == 0 also covers the "no digits consumed" case (e.g. empty string).
    if (errno != 0 || n == 0 || n > UINT_MAX)
        return false;
    if (*tail != '\0')
        return false;
    return true;
}
0190 
0191 void parse_connect_string(const string & connectString,
0192     string *host, bool *host_p,
0193     string *user, bool *user_p,
0194     string *password, bool *password_p,
0195     string *db, bool *db_p,
0196     string *unix_socket, bool *unix_socket_p,
0197     int *port, bool *port_p, string *ssl_ca, bool *ssl_ca_p,
0198     string *ssl_cert, bool *ssl_cert_p, string *ssl_key, bool *ssl_key_p,
0199     int *local_infile, bool *local_infile_p,
0200     string *charset, bool *charset_p, bool *reconnect_p,
0201     unsigned int *connect_timeout, bool *connect_timeout_p,
0202     unsigned int *read_timeout, bool *read_timeout_p,
0203     unsigned int *write_timeout, bool *write_timeout_p)
0204 {
0205     *host_p = false;
0206     *user_p = false;
0207     *password_p = false;
0208     *db_p = false;
0209     *unix_socket_p = false;
0210     *port_p = false;
0211     *ssl_ca_p = false;
0212     *ssl_cert_p = false;
0213     *ssl_key_p = false;
0214     *local_infile_p = false;
0215     *charset_p = false;
0216     *reconnect_p = false;
0217     *connect_timeout_p = false;
0218     *read_timeout_p = false;
0219     *write_timeout_p = false;
0220     string err = "Malformed connection string.";
0221     string::const_iterator i = connectString.begin(),
0222         end = connectString.end();
0223     while (i != end)
0224     {
0225         skip_white(&i, end, true);
0226         if (i == end)
0227         {
0228             return;
0229         }
0230         string par = param_name(&i, end);
0231         skip_white(&i, end, false);
0232         if (*i == '=')
0233         {
0234             ++i;
0235         }
0236         else
0237         {
0238             throw soci_error(err);
0239         }
0240         skip_white(&i, end, false);
0241         string val = param_value(&i, end);
0242         if (par == "port" && !*port_p)
0243         {
0244             if (!valid_int(val))
0245             {
0246                 throw soci_error(err);
0247             }
0248             *port = std::atoi(val.c_str());
0249             if (*port < 0)
0250             {
0251                 throw soci_error(err);
0252             }
0253             *port_p = true;
0254         }
0255         else if (par == "host" && !*host_p)
0256         {
0257             *host = val;
0258             *host_p = true;
0259         }
0260         else if (par == "user" && !*user_p)
0261         {
0262             *user = val;
0263             *user_p = true;
0264         }
0265         else if ((par == "pass" || par == "password") && !*password_p)
0266         {
0267             *password = val;
0268             *password_p = true;
0269         }
0270         else if ((par == "db" || par == "dbname" || par == "service") and !*db_p)
0271         {
0272             *db = val;
0273             *db_p = true;
0274         }
0275         else if (par == "unix_socket" && !*unix_socket_p)
0276         {
0277             *unix_socket = val;
0278             *unix_socket_p = true;
0279         }
0280         else if (par == "sslca" && !*ssl_ca_p)
0281         {
0282             *ssl_ca = val;
0283             *ssl_ca_p = true;
0284         }
0285         else if (par == "sslcert" && !*ssl_cert_p)
0286         {
0287             *ssl_cert = val;
0288             *ssl_cert_p = true;
0289         }
0290         else if (par == "sslkey" && !*ssl_key_p)
0291         {
0292             *ssl_key = val;
0293             *ssl_key_p = true;
0294         }
0295         else if (par == "local_infile" && !*local_infile_p)
0296         {
0297             if (!valid_int(val))
0298             {
0299                 throw soci_error(err);
0300             }
0301             *local_infile = std::atoi(val.c_str());
0302             if (*local_infile != 0 && *local_infile != 1)
0303             {
0304                 throw soci_error(err);
0305             }
0306             *local_infile_p = true;
0307         } else if (par == "charset" && !*charset_p)
0308         {
0309             *charset = val;
0310             *charset_p = true;
0311         } else if (par == "reconnect" && !*reconnect_p)
0312         {
0313             if (val != "1")
0314                 throw soci_error("\"reconnect\" option may only be set to 1");
0315 
0316             *reconnect_p = true;
0317         } else if (par == "connect_timeout" && !*connect_timeout_p)
0318         {
0319             if (!valid_uint(val))
0320                 throw soci_error(err);
0321             char *endp;
0322             *connect_timeout = std::strtoul(val.c_str(), &endp, 10);
0323             *connect_timeout_p = true;
0324         } else if (par == "read_timeout" && !*read_timeout_p)
0325         {
0326             if (!valid_uint(val))
0327                 throw soci_error(err);
0328             char *endp;
0329             *read_timeout = std::strtoul(val.c_str(), &endp, 10);
0330             *read_timeout_p = true;
0331         } else if (par == "write_timeout" && !*write_timeout_p)
0332         {
0333             if (!valid_uint(val))
0334                 throw soci_error(err);
0335             char *endp;
0336             *write_timeout = std::strtoul(val.c_str(), &endp, 10);
0337             *write_timeout_p = true;
0338         }
0339         else
0340         {
0341             throw soci_error(err);
0342         }
0343     }
0344 }
0345 
0346 } // namespace anonymous
0347 
0348 
0349 #ifdef __clang__
0350 #pragma clang diagnostic push
0351 #pragma clang diagnostic ignored "-Wuninitialized"
0352 #endif
0353 
0354 #if defined(__GNUC__) && ( __GNUC__ > 4 || (__GNUC__ == 4 && (__GNUC_MINOR__ > 6)))
0355 #pragma GCC diagnostic push
0356 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
0357 #endif
0358 
0359 
0360 mysql_session_backend::mysql_session_backend(
0361     connection_parameters const & parameters)
0362 {
0363     mysql_library::ensure_initialized();
0364 
0365     string host, user, password, db, unix_socket, ssl_ca, ssl_cert, ssl_key,
0366         charset;
0367     int port, local_infile;
0368     unsigned int connect_timeout, read_timeout, write_timeout;
0369     bool host_p, user_p, password_p, db_p, unix_socket_p, port_p,
0370         ssl_ca_p, ssl_cert_p, ssl_key_p, local_infile_p, charset_p, reconnect_p,
0371         connect_timeout_p, read_timeout_p, write_timeout_p;
0372     parse_connect_string(parameters.get_connect_string(), &host, &host_p, &user, &user_p,
0373         &password, &password_p, &db, &db_p,
0374         &unix_socket, &unix_socket_p, &port, &port_p,
0375         &ssl_ca, &ssl_ca_p, &ssl_cert, &ssl_cert_p, &ssl_key, &ssl_key_p,
0376         &local_infile, &local_infile_p, &charset, &charset_p, &reconnect_p,
0377         &connect_timeout, &connect_timeout_p,
0378         &read_timeout, &read_timeout_p,
0379         &write_timeout, &write_timeout_p);
0380     conn_ = mysql_init(NULL);
0381     if (conn_ == NULL)
0382     {
0383         throw soci_error("mysql_init() failed.");
0384     }
0385     if (reconnect_p)
0386     {
0387         #if MYSQL_VERSION_ID < 8
0388             my_bool reconnect = 1;
0389         #else
0390             bool reconnect = 1;
0391         #endif
0392         if (0 != mysql_options(conn_, MYSQL_OPT_RECONNECT, &reconnect))
0393         {
0394             clean_up();
0395             throw soci_error("mysql_options(MYSQL_OPT_RECONNECT) failed.");
0396         }
0397     }
0398     if (charset_p)
0399     {
0400         if (0 != mysql_options(conn_, MYSQL_SET_CHARSET_NAME, charset.c_str()))
0401         {
0402             clean_up();
0403             throw soci_error("mysql_options(MYSQL_SET_CHARSET_NAME) failed.");
0404         }
0405     }
0406     if (ssl_ca_p)
0407     {
0408         mysql_ssl_set(conn_, ssl_key_p ? ssl_key.c_str() : NULL,
0409                       ssl_cert_p ? ssl_cert.c_str() : NULL,
0410                       ssl_ca.c_str(), 0, 0);
0411     }
0412     if (local_infile_p && local_infile == 1)
0413     {
0414         if (0 != mysql_options(conn_, MYSQL_OPT_LOCAL_INFILE, NULL))
0415         {
0416             clean_up();
0417             throw soci_error(
0418                 "mysql_options() failed when trying to set local-infile.");
0419         }
0420     }
0421     if (connect_timeout_p)
0422     {
0423         if (0 != mysql_options(conn_, MYSQL_OPT_CONNECT_TIMEOUT, &connect_timeout))
0424         {
0425             clean_up();
0426             throw soci_error("mysql_options(MYSQL_OPT_CONNECT_TIMEOUT) failed.");
0427         }
0428     }
0429     if (read_timeout_p)
0430     {
0431         if (0 != mysql_options(conn_, MYSQL_OPT_READ_TIMEOUT, &read_timeout))
0432         {
0433             clean_up();
0434             throw soci_error("mysql_options(MYSQL_OPT_READ_TIMEOUT) failed.");
0435         }
0436     }
0437     if (write_timeout_p)
0438     {
0439         if (0 != mysql_options(conn_, MYSQL_OPT_WRITE_TIMEOUT, &write_timeout))
0440         {
0441             clean_up();
0442             throw soci_error("mysql_options(MYSQL_OPT_WRITE_TIMEOUT) failed.");
0443         }
0444     }
0445     if (mysql_real_connect(conn_,
0446             host_p ? host.c_str() : NULL,
0447             user_p ? user.c_str() : NULL,
0448             password_p ? password.c_str() : NULL,
0449             db_p ? db.c_str() : NULL,
0450             port_p ? port : 0,
0451             unix_socket_p ? unix_socket.c_str() : NULL,
0452 #ifdef CLIENT_MULTI_RESULTS
0453             CLIENT_FOUND_ROWS | CLIENT_MULTI_RESULTS) == NULL)
0454 #else
0455             CLIENT_FOUND_ROWS) == NULL)
0456 #endif
0457     {
0458         string errMsg = mysql_error(conn_);
0459         unsigned int errNum = mysql_errno(conn_);
0460         clean_up();
0461         throw mysql_soci_error(errMsg, errNum);
0462     }
0463 }
0464 
0465 #if defined(__GNUC__) && ( __GNUC__ > 4 || (__GNUC__ == 4 && (__GNUC_MINOR__ > 6)))
0466 #pragma GCC diagnostic pop
0467 #endif
0468 
0469 #ifdef __clang__
0470 #pragma clang diagnostic pop
0471 #endif
0472 
0473 
0474 
0475 mysql_session_backend::~mysql_session_backend()
0476 {
0477     clean_up();
0478 }
0479 
0480 namespace // unnamed
0481 {
0482 
0483 // helper function for hardcoded queries
0484 void hard_exec(MYSQL *conn, const string & query)
0485 {
0486     if (0 != mysql_real_query(conn, query.c_str(),
0487             static_cast<unsigned long>(query.size())))
0488     {
0489         //throw soci_error(mysql_error(conn));
0490         string errMsg = mysql_error(conn);
0491         unsigned int errNum = mysql_errno(conn);
0492         throw mysql_soci_error(errMsg, errNum);
0493 
0494     }
0495 }
0496 
0497 } // namespace unnamed
0498 
0499 bool mysql_session_backend::is_connected()
0500 {
0501     return mysql_ping(conn_) == 0;
0502 }
0503 
0504 void mysql_session_backend::begin()
0505 {
0506     hard_exec(conn_, "BEGIN");
0507 }
0508 
0509 void mysql_session_backend::commit()
0510 {
0511     hard_exec(conn_, "COMMIT");
0512 }
0513 
0514 void mysql_session_backend::rollback()
0515 {
0516     hard_exec(conn_, "ROLLBACK");
0517 }
0518 
0519 bool mysql_session_backend::get_last_insert_id(
0520     session & /* s */, std::string const & /* table */, long long & value)
0521 {
0522     value = static_cast<long long>(mysql_insert_id(conn_));
0523 
0524     return true;
0525 }
0526 
0527 void mysql_session_backend::clean_up()
0528 {
0529     if (conn_ != NULL)
0530     {
0531         mysql_close(conn_);
0532         conn_ = NULL;
0533     }
0534 }
0535 
0536 mysql_statement_backend * mysql_session_backend::make_statement_backend()
0537 {
0538     return new mysql_statement_backend(*this);
0539 }
0540 
0541 mysql_rowid_backend * mysql_session_backend::make_rowid_backend()
0542 {
0543     return new mysql_rowid_backend(*this);
0544 }
0545 
0546 mysql_blob_backend * mysql_session_backend::make_blob_backend()
0547 {
0548     return new mysql_blob_backend(*this);
0549 }