File indexing completed on 2025-02-23 05:15:21

0001 //
0002 // Copyright (C) 2004-2006 Maciej Sobczak, Stephen Hutton, David Courtney
0003 // Distributed under the Boost Software License, Version 1.0.
0004 // (See accompanying file LICENSE_1_0.txt or copy at
0005 // http://www.boost.org/LICENSE_1_0.txt)
0006 //
0007 
0008 #define SOCI_SQLITE3_SOURCE
0009 #include "soci/sqlite3/soci-sqlite3.h"
0010 
0011 #include "soci/connection-parameters.h"
0012 
0013 #include "soci-cstrtoi.h"
0014 
0015 #include <functional>
0016 #include <sstream>
0017 #include <string>
0018 
0019 #ifdef _MSC_VER
0020 #pragma warning(disable:4355)
0021 #endif
0022 
0023 using namespace soci;
0024 using namespace soci::details;
0025 using namespace sqlite_api;
0026 
0027 namespace // anonymous
0028 {
0029 
0030 // Callback function use to construct the error message in the provided stream.
0031 //
0032 // SQLite3 own error message will be appended to it.
0033 using error_callback = std::function<void (std::ostream& ostr)>;
0034 
0035 // helper function for hardcoded queries: this is a simple wrapper for
0036 // sqlite3_exec() which throws an exception on error.
0037 void execude_hardcoded(sqlite_api::sqlite3* conn, char const* const query,
0038                        error_callback const& errCallback,
0039                        int (*callback)(void*, int, char**, char**) = NULL,
0040                        void* callback_arg = NULL)
0041 {
0042     char *zErrMsg = 0;
0043     int const res = sqlite3_exec(conn, query, callback, callback_arg, &zErrMsg);
0044     if (res != SQLITE_OK)
0045     {
0046         std::ostringstream ss;
0047         errCallback(ss);
0048         ss << ": " << zErrMsg;
0049         sqlite3_free(zErrMsg);
0050         throw sqlite3_soci_error(ss.str(), res);
0051     }
0052 }
0053 
0054 // Simpler to use overload which uses a hard coded error message.
0055 void execude_hardcoded(sqlite_api::sqlite3* conn, char const* const query, char const* const errMsg,
0056                        int (*callback)(void*, int, char**, char**) = NULL,
0057                        void* callback_arg = NULL)
0058 {
0059     return execude_hardcoded(conn, query,
0060         [errMsg](std::ostream& ostr) { ostr << errMsg; },
0061         callback, callback_arg
0062     );
0063 }
0064 
0065 void check_sqlite_err(sqlite_api::sqlite3* conn, int res,
0066                       error_callback const& errCallback)
0067 {
0068     if (SQLITE_OK != res)
0069     {
0070         const char *zErrMsg = sqlite3_errmsg(conn);
0071         std::ostringstream ss;
0072         errCallback(ss);
0073         ss << ": " << zErrMsg;
0074         sqlite3_close(conn); // connection must be closed here
0075         throw sqlite3_soci_error(ss.str(), res);
0076     }
0077 }
0078 
0079 void check_sqlite_err(sqlite_api::sqlite3* conn, int res, char const* const errMsg)
0080 {
0081     return check_sqlite_err(conn, res, [errMsg](std::ostream& ostr) { ostr << errMsg; });
0082 }
0083 
0084 } // namespace anonymous
0085 
// sqlite3_exec() row callback: ctxt points to a bool, which receives whether
// the delivered row has at least one column. Returning 0 tells sqlite3_exec()
// to keep going.
static int sequence_table_exists_callback(void* ctxt, int result_columns, char**, char**)
{
    bool& found = *static_cast<bool*>(ctxt);
    found = (result_columns > 0);
    return 0;
}
0092 
0093 static bool check_if_sequence_table_exists(sqlite_api::sqlite3* conn)
0094 {
0095     bool sequence_table_exists = false;
0096     execude_hardcoded
0097     (
0098       conn,
0099       "select name from sqlite_master where type='table' and name='sqlite_sequence'",
0100       "Failed checking if the sqlite_sequence table exists",
0101       &sequence_table_exists_callback,
0102       &sequence_table_exists
0103     );
0104 
0105     return sequence_table_exists;
0106 }
0107 
0108 sqlite3_session_backend::sqlite3_session_backend(
0109     connection_parameters const & parameters)
0110     : sequence_table_exists_(false)
0111 {
0112     int timeout = 0;
0113     int connection_flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE;
0114     std::string vfs;
0115     std::string synchronous;
0116     std::string foreignKeys;
0117     std::string const & connectString = parameters.get_connect_string();
0118     std::string dbname(connectString);
0119     std::stringstream ssconn(connectString);
0120     while (!ssconn.eof() && ssconn.str().find('=') != std::string::npos)
0121     {
0122         std::string key, val;
0123         std::getline(ssconn, key, '=');
0124         std::getline(ssconn, val, ' ');
0125 
0126         if (val.size()>0 && val[0]=='\"')
0127         {
0128             std::string quotedVal = val.erase(0, 1);
0129 
0130             if (quotedVal[quotedVal.size()-1] ==  '\"')
0131             {
0132                 quotedVal.erase(val.size()-1);
0133             }
0134             else // space inside value string
0135             {
0136                 std::getline(ssconn, val, '\"');
0137                 quotedVal = quotedVal + " " + val;
0138                 std::string keepspace;
0139                 std::getline(ssconn, keepspace, ' ');
0140             }
0141 
0142             val = quotedVal;
0143         }
0144 
0145         if ("dbname" == key || "db" == key)
0146         {
0147             dbname = val;
0148         }
0149         else if ("timeout" == key)
0150         {
0151             std::istringstream converter(val);
0152             converter >> timeout;
0153         }
0154         else if ("synchronous" == key)
0155         {
0156             synchronous = val;
0157         }
0158         else if ("readonly" == key)
0159         {
0160             connection_flags = (connection_flags | SQLITE_OPEN_READONLY) & ~(SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE);
0161         }
0162         else if ("nocreate" == key)
0163         {
0164             connection_flags &= ~SQLITE_OPEN_CREATE;
0165         }
0166         else if ("shared_cache" == key && "true" == val)
0167         {
0168             connection_flags |= SQLITE_OPEN_SHAREDCACHE;
0169         }
0170         else if ("vfs" == key)
0171         {
0172             vfs = val;
0173         }
0174         else if ("foreign_keys" == key)
0175         {
0176             foreignKeys = val;
0177         }
0178     }
0179 
0180     int res = sqlite3_open_v2(dbname.c_str(), &conn_, connection_flags, (vfs.empty()?NULL:vfs.c_str()));
0181     check_sqlite_err(conn_, res,
0182         [&dbname](std::ostream& ostr)
0183         {
0184             ostr << "Cannot establish connection to \"" << dbname << "\"";
0185         }
0186     );
0187 
0188     if (!synchronous.empty())
0189     {
0190         std::string const query("pragma synchronous=" + synchronous);
0191         execude_hardcoded(conn_, query.c_str(),
0192             [&synchronous](std::ostream& ostr)
0193             {
0194                 ostr << "Setting synchronous pragma to \"" << synchronous << "\" failed";
0195             }
0196         );
0197     }
0198 
0199     if (!foreignKeys.empty())
0200     {
0201         std::string const query("pragma foreign_keys=" + foreignKeys);
0202         execude_hardcoded(conn_, query.c_str(),
0203             [&foreignKeys](std::ostream& ostr)
0204             {
0205                 ostr << "Setting foreign_keys pragma to \"" << foreignKeys << "\" failed";
0206             }
0207         );
0208     }
0209 
0210     res = sqlite3_busy_timeout(conn_, timeout * 1000);
0211     check_sqlite_err(conn_, res, "Failed to set busy timeout for connection. ");
0212 }
0213 
0214 sqlite3_session_backend::~sqlite3_session_backend()
0215 {
0216     clean_up();
0217 }
0218 
0219 void sqlite3_session_backend::begin()
0220 {
0221     execude_hardcoded(conn_, "BEGIN", "Cannot begin transaction.");
0222 }
0223 
0224 void sqlite3_session_backend::commit()
0225 {
0226     execude_hardcoded(conn_, "COMMIT", "Cannot commit transaction.");
0227 }
0228 
0229 void sqlite3_session_backend::rollback()
0230 {
0231     execude_hardcoded(conn_, "ROLLBACK", "Cannot rollback transaction.");
0232 }
0233 
// Argument passed to store_single_value_callback(), which is used to retrieve
// a single numeric value from a hardcoded query.
//
// value_ is only meaningful once valid_ is true. It is still initialized so
// the struct never holds an indeterminate value.
struct single_value_callback_ctx
{
    single_value_callback_ctx() : value_(0), valid_(false) {}

    long long value_;
    bool valid_;
};
0243 
0244 static int store_single_value_callback(void* ctx, int result_columns, char** values, char**)
0245 {
0246     single_value_callback_ctx* arg = static_cast<single_value_callback_ctx*>(ctx);
0247 
0248     if (result_columns == 1 && values[0])
0249     {
0250         if (cstring_to_integer(arg->value_, values[0]))
0251             arg->valid_ = true;
0252     }
0253 
0254     return 0;
0255 }
0256 
0257 static std::string sanitize_table_name(std::string const& table)
0258 {
0259     std::string ret;
0260     ret.reserve(table.length());
0261     for (std::string::size_type pos = 0; pos < table.size(); ++pos)
0262     {
0263         if (isspace(table[pos]))
0264             throw sqlite3_soci_error("Table name must not contain whitespace", 0);
0265         const char c = table[pos];
0266         ret += c;
0267         if (c == '\'')
0268             ret += '\'';
0269         else if (c == '\"')
0270             ret += '\"';
0271     }
0272     return ret;
0273 }
0274 
0275 bool sqlite3_session_backend::get_last_insert_id(
0276     session &, std::string const & table, long long & value)
0277 {
0278     single_value_callback_ctx ctx;
0279     if (sequence_table_exists_ || check_if_sequence_table_exists(conn_))
0280     {
0281         // Once the sqlite_sequence table is created (because of a column marked AUTOINCREMENT)
0282         // it can never be dropped, so don't search for it again.
0283         sequence_table_exists_ = true;
0284 
0285         std::string const query =
0286             "select seq from sqlite_sequence where name ='" + sanitize_table_name(table) + "'";
0287         execude_hardcoded(conn_, query.c_str(),  "Unable to get value in sqlite_sequence",
0288                           &store_single_value_callback, &ctx);
0289 
0290         // The value will not be filled if the callback was never called.
0291         // It may mean either that nothing was inserted yet into the table
0292         // that has an AUTOINCREMENT column, or that the table does not have an AUTOINCREMENT
0293         // column.
0294         if (ctx.valid_)
0295         {
0296             value = ctx.value_;
0297             return true;
0298         }
0299     }
0300 
0301     // Fall-back just get the maximum rowid of what was already inserted in the
0302     // table. This has the disadvantage that if rows were deleted, then ids may be re-used.
0303     // But, if one cares about that, AUTOINCREMENT should be used anyway.
0304     std::string const max_rowid_query = "select max(rowid) from \"" + sanitize_table_name(table) + "\"";
0305     execude_hardcoded(conn_, max_rowid_query.c_str(),  "Unable to get max rowid",
0306                       &store_single_value_callback, &ctx);
0307     value = ctx.valid_ ? ctx.value_ : 0;
0308 
0309     return true;
0310 }
0311 
0312 void sqlite3_session_backend::clean_up()
0313 {
0314     sqlite3_close(conn_);
0315 }
0316 
0317 sqlite3_statement_backend * sqlite3_session_backend::make_statement_backend()
0318 {
0319     return new sqlite3_statement_backend(*this);
0320 }
0321 
0322 sqlite3_rowid_backend * sqlite3_session_backend::make_rowid_backend()
0323 {
0324     return new sqlite3_rowid_backend(*this);
0325 }
0326 
0327 sqlite3_blob_backend * sqlite3_session_backend::make_blob_backend()
0328 {
0329     return new sqlite3_blob_backend(*this);
0330 }