diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c index 3a022e661..98b7e6252 100644 --- a/ext/mysql2/statement.c +++ b/ext/mysql2/statement.c @@ -2,7 +2,7 @@ VALUE cMysql2Statement; extern VALUE mMysql2, cMysql2Error, cBigDecimal, cDateTime, cDate; -static VALUE sym_stream, intern_new_with_args, intern_each, intern_to_s; +static VALUE sym_stream, intern_new_with_args, intern_each, intern_to_s, intern_merge_bang; static VALUE intern_sec_fraction, intern_usec, intern_sec, intern_min, intern_hour, intern_day, intern_month, intern_year; #define GET_STATEMENT(self) \ @@ -184,7 +184,7 @@ static void set_buffer_for_string(MYSQL_BIND* bind_buffer, unsigned long *length * the buffer is a Ruby string pointer and not our memory to manage. */ #define FREE_BINDS \ - for (i = 0; i < argc; i++) { \ + for (i = 0; i < c; i++) { \ if (bind_buffers[i].buffer && NIL_P(params_enc[i])) { \ xfree(bind_buffers[i].buffer); \ } \ @@ -248,8 +248,10 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) { unsigned long *length_buffers = NULL; unsigned long bind_count; long i; + int c; MYSQL_STMT *stmt; MYSQL_RES *metadata; + VALUE opts; VALUE current; VALUE resultObj; VALUE *params_enc; @@ -261,14 +263,17 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) { conn_enc = rb_to_encoding(wrapper->encoding); - /* Scratch space for string encoding exports, allocate on the stack. */ - params_enc = alloca(sizeof(VALUE) * argc); + // Get count of ordinary arguments, and extract hash opts/keyword arguments + c = rb_scan_args(argc, argv, "*:", NULL, &opts); + + // Scratch space for string encoding exports, allocate on the stack + params_enc = alloca(sizeof(VALUE) * c); stmt = stmt_wrapper->stmt; bind_count = mysql_stmt_param_count(stmt); - if (argc != (long)bind_count) { - rb_raise(cMysql2Error, "Bind parameter count (%ld) doesn't match number of arguments (%d)", bind_count, argc); + if (c != (long)bind_count) { + rb_raise(cMysql2Error, "Bind parameter count (%ld) doesn't match number of arguments (%d)", bind_count, c); } // setup any bind variables in the query @@ -276,7 +281,7 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) { bind_buffers = xcalloc(bind_count, sizeof(MYSQL_BIND)); length_buffers = xcalloc(bind_count, sizeof(unsigned long)); - for (i = 0; i < argc; i++) { + for (i = 0; i < c; i++) { bind_buffers[i].buffer = NULL; params_enc[i] = Qnil; @@ -416,10 +421,16 @@ static VALUE rb_mysql_stmt_execute(int argc, VALUE *argv, VALUE self) { return Qnil; } + // Important to duplicate the hash, will receive merge! if extra opts current = rb_hash_dup(rb_iv_get(stmt_wrapper->client, "@query_options")); (void)RB_GC_GUARD(current); Check_Type(current, T_HASH); + // Merge in hash opts/keyword arguments + if (!NIL_P(opts)) { + rb_funcall(current, intern_merge_bang, 1, opts); + } + is_streaming = (Qtrue == rb_hash_aref(current, sym_stream)); if (!is_streaming) { // recieve the whole result set from the server @@ -562,4 +573,5 @@ void init_mysql2_statement() { intern_year = rb_intern("year"); intern_to_s = rb_intern("to_s"); + intern_merge_bang = rb_intern("merge!"); } diff --git a/lib/mysql2/statement.rb b/lib/mysql2/statement.rb index 482ccce19..d4510a139 100644 --- a/lib/mysql2/statement.rb +++ b/lib/mysql2/statement.rb @@ -5,14 +5,14 @@ class Statement include Enumerable if Thread.respond_to?(:handle_interrupt) - def execute(*args) + def execute(*args, **kwargs) Thread.handle_interrupt(::Mysql2::Util::TIMEOUT_ERROR_CLASS => :never) do - _execute(*args) + _execute(*args, **kwargs) end end else - def execute(*args) - _execute(*args) + def execute(*args, **kwargs) + _execute(*args, **kwargs) end end end diff --git a/spec/mysql2/statement_spec.rb b/spec/mysql2/statement_spec.rb index 7d856cbd3..a6a14aa62 100644 --- a/spec/mysql2/statement_spec.rb +++ b/spec/mysql2/statement_spec.rb @@ -88,6 +88,20 @@ def stmt_count expect(result.to_a).to eq(['max1' => int64_max1, 'max2' => int64_max2, 'max3' => int64_max3, 'min1' => int64_min1, 'min2' => int64_min2, 'min3' => int64_min3]) end + it "should accept keyword arguments on statement execute" do + stmt = @client.prepare 'SELECT 1 AS a' + + expect(stmt.execute(as: :hash).first).to eq("a" => 1) + expect(stmt.execute(as: :array).first).to eq([1]) + end + + it "should accept bind arguments and keyword arguments on statement execute" do + stmt = @client.prepare 'SELECT ? AS a' + + expect(stmt.execute(1, as: :hash).first).to eq("a" => 1) + expect(stmt.execute(1, as: :array).first).to eq([1]) + end + it "should keep its result after other query" do @client.query 'USE test' @client.query 'CREATE TABLE IF NOT EXISTS mysql2_stmt_q(a int)' @@ -188,10 +202,9 @@ def stmt_count end it "should warn but still work if cache_rows is set to false" do - @client.query_options[:cache_rows] = false statement = @client.prepare 'SELECT 1' result = nil - expect { result = statement.execute.to_a }.to output(/:cache_rows is forced for prepared statements/).to_stderr + expect { result = statement.execute(cache_rows: false).to_a }.to output(/:cache_rows is forced for prepared statements/).to_stderr expect(result.length).to eq(1) end @@ -240,10 +253,7 @@ def stmt_count it "should be able to stream query result" do n = 1 stmt = @client.prepare("SELECT 1 UNION SELECT 2") - - @client.query_options.merge!(stream: true, cache_rows: false, as: :array) - - stmt.execute.each do |r| + stmt.execute(stream: true, cache_rows: false, as: :array).each do |r| case n when 1 expect(r).to eq([1]) @@ -269,23 +279,17 @@ def stmt_count end it "should yield rows as hash's with symbol keys if :symbolize_keys was set to true" do - @client.query_options[:symbolize_keys] = true - @result = @client.prepare("SELECT 1").execute + @result = @client.prepare("SELECT 1").execute(symbolize_keys: true) @result.each do |row| expect(row.keys.first).to be_an_instance_of(Symbol) end - @client.query_options[:symbolize_keys] = false end it "should be able to return results as an array" do - @client.query_options[:as] = :array - - @result = @client.prepare("SELECT 1").execute + @result = @client.prepare("SELECT 1").execute(as: :array) @result.each do |row| expect(row).to be_an_instance_of(Array) end - - @client.query_options[:as] = :hash end it "should cache previously yielded results by default" do @@ -294,35 +298,21 @@ def stmt_count end it "should yield different value for #first if streaming" do - @client.query_options[:stream] = true - @client.query_options[:cache_rows] = false - - result = @client.prepare("SELECT 1 UNION SELECT 2").execute + result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: true, cache_rows: true) expect(result.first).not_to eql(result.first) - - @client.query_options[:stream] = false - @client.query_options[:cache_rows] = true end it "should yield the same value for #first if streaming is disabled" do - @client.query_options[:stream] = false - result = @client.prepare("SELECT 1 UNION SELECT 2").execute + result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: false) expect(result.first).to eql(result.first) end it "should throw an exception if we try to iterate twice when streaming is enabled" do - @client.query_options[:stream] = true - @client.query_options[:cache_rows] = false - - result = @client.prepare("SELECT 1 UNION SELECT 2").execute - + result = @client.prepare("SELECT 1 UNION SELECT 2").execute(stream: true, cache_rows: false) expect do result.each {} result.each {} end.to raise_exception(Mysql2::Error) - - @client.query_options[:stream] = false - @client.query_options[:cache_rows] = true end end @@ -371,21 +361,20 @@ def stmt_count context "cast booleans for TINYINT if :cast_booleans is enabled" do # rubocop:disable Style/Semicolon - let(:client) { new_client(cast_booleans: true) } - let(:id1) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 1)'; client.last_id } - let(:id2) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 0)'; client.last_id } - let(:id3) { client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES (-1)'; client.last_id } + let(:id1) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 1)'; @client.last_id } + let(:id2) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES ( 0)'; @client.last_id } + let(:id3) { @client.query 'INSERT INTO mysql2_test (bool_cast_test) VALUES (-1)'; @client.last_id } # rubocop:enable Style/Semicolon after do - client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2},#{id3})" + @client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2},#{id3})" end it "should return TrueClass or FalseClass for a TINYINT value if :cast_booleans is enabled" do - query = client.prepare 'SELECT bool_cast_test FROM mysql2_test WHERE id = ?' - result1 = query.execute id1 - result2 = query.execute id2 - result3 = query.execute id3 + query = @client.prepare 'SELECT bool_cast_test FROM mysql2_test WHERE id = ?' + result1 = query.execute id1, cast_booleans: true + result2 = query.execute id2, cast_booleans: true + result3 = query.execute id3, cast_booleans: true expect(result1.first['bool_cast_test']).to be true expect(result2.first['bool_cast_test']).to be false expect(result3.first['bool_cast_test']).to be true @@ -394,19 +383,18 @@ def stmt_count context "cast booleans for BIT(1) if :cast_booleans is enabled" do # rubocop:disable Style/Semicolon - let(:client) { new_client(cast_booleans: true) } - let(:id1) { client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (1)'; client.last_id } - let(:id2) { client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (0)'; client.last_id } + let(:id1) { @client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (1)'; @client.last_id } + let(:id2) { @client.query 'INSERT INTO mysql2_test (single_bit_test) VALUES (0)'; @client.last_id } # rubocop:enable Style/Semicolon after do - client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2})" + @client.query "DELETE from mysql2_test WHERE id IN(#{id1},#{id2})" end it "should return TrueClass or FalseClass for a BIT(1) value if :cast_booleans is enabled" do - query = client.prepare 'SELECT single_bit_test FROM mysql2_test WHERE id = ?' - result1 = query.execute id1 - result2 = query.execute id2 + query = @client.prepare 'SELECT single_bit_test FROM mysql2_test WHERE id = ?' + result1 = query.execute id1, cast_booleans: true + result2 = query.execute id2, cast_booleans: true expect(result1.first['single_bit_test']).to be true expect(result2.first['single_bit_test']).to be false end