diff --git a/ext/sqlite3/database.c b/ext/sqlite3/database.c index 8e833faf..621dc7aa 100644 --- a/ext/sqlite3/database.c +++ b/ext/sqlite3/database.c @@ -49,15 +49,16 @@ discard_db(sqlite3RubyPtr ctx) } ctx->db = NULL; + ctx->flags |= SQLITE3_RB_DATABASE_DISCARDED; } static void close_or_discard_db(sqlite3RubyPtr ctx) { if (ctx->db) { - int isReadonly = (ctx->flags & SQLITE_OPEN_READONLY); + int is_readonly = (ctx->flags & SQLITE3_RB_DATABASE_READONLY); - if (isReadonly || ctx->owner == getpid()) { + if (is_readonly || ctx->owner == getpid()) { // Ordinary close. sqlite3_close_v2(ctx->db); ctx->db = NULL; @@ -153,7 +154,9 @@ rb_sqlite3_open_v2(VALUE self, VALUE file, VALUE mode, VALUE zvfs) ); CHECK(ctx->db, status); - ctx->flags = flags; + if (flags & SQLITE_OPEN_READONLY) { + ctx->flags |= SQLITE3_RB_DATABASE_READONLY; + } return self; } @@ -943,11 +946,10 @@ rb_sqlite3_open16(VALUE self, VALUE file) #endif #endif - status = sqlite3_open16(utf16_string_value_ptr(file), &ctx->db); - - // these are the perm flags used implicitly by sqlite3_open16, + // sqlite3_open16 implicitly uses flags (SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE) // see https://www.sqlite.org/capi3ref.html#sqlite3_open - ctx->flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE; + // so we do not ever set SQLITE3_RB_DATABASE_READONLY in ctx->flags + status = sqlite3_open16(utf16_string_value_ptr(file), &ctx->db); CHECK(ctx->db, status) diff --git a/ext/sqlite3/database.h b/ext/sqlite3/database.h index 1ef7b245..04124881 100644 --- a/ext/sqlite3/database.h +++ b/ext/sqlite3/database.h @@ -3,6 +3,10 @@ #include +/* bits in the `flags` field */ +#define SQLITE3_RB_DATABASE_READONLY 0x01 +#define SQLITE3_RB_DATABASE_DISCARDED 0x02 + struct _sqlite3Ruby { sqlite3 *db; VALUE busy_handler; diff --git a/ext/sqlite3/statement.c b/ext/sqlite3/statement.c index 690cd0f8..705b7679 100644 --- a/ext/sqlite3/statement.c +++ b/ext/sqlite3/statement.c @@ -1,22 +1,12 @@ #include #define REQUIRE_OPEN_STMT(_ctxt) \ - if(!_ctxt->st) \ + if (!_ctxt->st) \ rb_raise(rb_path2class("SQLite3::Exception"), "cannot use a closed statement"); -static void -require_open_db(VALUE stmt_rb) -{ - VALUE closed_p = rb_funcall( - rb_iv_get(stmt_rb, "@connection"), - rb_intern("closed?"), 0); - - if (RTEST(closed_p)) { - rb_raise(rb_path2class("SQLite3::Exception"), - "cannot use a statement associated with a closed database"); - } -} - +#define REQUIRE_LIVE_DB(_ctxt) \ + if (_ctxt->db->flags & SQLITE3_RB_DATABASE_DISCARDED) \ + rb_raise(rb_path2class("SQLite3::Exception"), "cannot use a statement associated with a discarded database"); VALUE cSqlite3Statement; @@ -71,6 +61,11 @@ prepare(VALUE self, VALUE db, VALUE sql) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); + /* Dereferencing a pointer to the database struct will be faster than accessing it through the + * instance variable @connection. The struct pointer is guaranteed to be live because instance + * variable will keep it from being GCed. */ + ctx->db = db_ctx; + #ifdef HAVE_SQLITE3_PREPARE_V2 status = sqlite3_prepare_v2( #else @@ -135,7 +130,7 @@ step(VALUE self) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); if (ctx->done_p) { return Qnil; } @@ -232,7 +227,7 @@ bind_param(VALUE self, VALUE key, VALUE value) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); switch (TYPE(key)) { @@ -326,7 +321,7 @@ reset_bang(VALUE self) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); sqlite3_reset(ctx->st); @@ -348,7 +343,7 @@ clear_bindings_bang(VALUE self) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); sqlite3_clear_bindings(ctx->st); @@ -382,7 +377,7 @@ column_count(VALUE self) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); return INT2NUM(sqlite3_column_count(ctx->st)); @@ -415,7 +410,7 @@ column_name(VALUE self, VALUE index) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); name = sqlite3_column_name(ctx->st, (int)NUM2INT(index)); @@ -440,7 +435,7 @@ column_decltype(VALUE self, VALUE index) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); name = sqlite3_column_decltype(ctx->st, (int)NUM2INT(index)); @@ -459,7 +454,7 @@ bind_parameter_count(VALUE self) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); return INT2NUM(sqlite3_bind_parameter_count(ctx->st)); @@ -568,7 +563,7 @@ stats_as_hash(VALUE self) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); VALUE arg = rb_hash_new(); @@ -587,7 +582,7 @@ stat_for(VALUE self, VALUE key) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); if (SYMBOL_P(key)) { @@ -609,7 +604,7 @@ memused(VALUE self) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); return INT2NUM(sqlite3_stmt_status(ctx->st, SQLITE_STMTSTATUS_MEMUSED, 0)); @@ -628,7 +623,7 @@ database_name(VALUE self, VALUE index) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); return SQLITE3_UTF8_STR_NEW2( @@ -647,7 +642,7 @@ get_sql(VALUE self) sqlite3StmtRubyPtr ctx; TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); return rb_obj_freeze(SQLITE3_UTF8_STR_NEW2(sqlite3_sql(ctx->st))); @@ -667,7 +662,7 @@ get_expanded_sql(VALUE self) TypedData_Get_Struct(self, sqlite3StmtRuby, &statement_type, ctx); - require_open_db(self); + REQUIRE_LIVE_DB(ctx); REQUIRE_OPEN_STMT(ctx); expanded_sql = sqlite3_expanded_sql(ctx->st); diff --git a/ext/sqlite3/statement.h b/ext/sqlite3/statement.h index d5dd343f..faae92b2 100644 --- a/ext/sqlite3/statement.h +++ b/ext/sqlite3/statement.h @@ -5,6 +5,7 @@ struct _sqlite3StmtRuby { sqlite3_stmt *st; + sqlite3Ruby *db; int done_p; }; diff --git a/test/test_discarding.rb b/test/test_discarding.rb index d4fd05c9..5877c9a4 100644 --- a/test/test_discarding.rb +++ b/test/test_discarding.rb @@ -160,7 +160,7 @@ def test_a_discarded_connection_with_statements db.send(:discard) e = assert_raises(SQLite3::Exception) { stmt.execute } - assert_match(/cannot use a statement associated with a closed database/, e.message) + assert_match(/cannot use a statement associated with a discarded database/, e.message) assert_nothing_raised { stmt.close } assert_predicate(stmt, :closed?) diff --git a/test/test_statement.rb b/test/test_statement.rb index 7d582bd8..b6a55001 100644 --- a/test/test_statement.rb +++ b/test/test_statement.rb @@ -135,6 +135,13 @@ def test_new_closed_handle end end + def test_closed_db_behavior + @db.close + result = nil + assert_nothing_raised { result = @stmt.execute } + refute_nil result + end + def test_new_with_remainder stmt = SQLite3::Statement.new(@db, "select 'foo';bar") assert_equal "bar", stmt.remainder