Skip to content

Commit

Permalink
Add missing row separator encoding conversion (#69)
Browse files Browse the repository at this point in the history
The conversion logic is borrowed from ruby/ruby's io.c:
https://github.com/ruby/ruby/blob/40391faeab608665da87a05c686c074f91a5a206/io.c#L4059-L4079

Fix #68

Reported by IWAMOTO Kouichi. Thanks!!!
  • Loading branch information
kou authored Nov 8, 2023
1 parent d4f7db2 commit 4b170c1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
63 changes: 41 additions & 22 deletions ext/stringio/stringio.c
Original file line number Diff line number Diff line change
Expand Up @@ -1143,38 +1143,57 @@ struct getline_arg {
};

static struct getline_arg *
prepare_getline_args(struct getline_arg *arg, int argc, VALUE *argv)
prepare_getline_args(struct StringIO *ptr, struct getline_arg *arg, int argc, VALUE *argv)
{
VALUE str, lim, opts;
VALUE rs, lim, opts;
long limit = -1;
int respect_chomp;

argc = rb_scan_args(argc, argv, "02:", &str, &lim, &opts);
respect_chomp = argc == 0 || !NIL_P(str);
argc = rb_scan_args(argc, argv, "02:", &rs, &lim, &opts);
respect_chomp = argc == 0 || !NIL_P(rs);
switch (argc) {
case 0:
str = rb_rs;
rs = rb_rs;
break;

case 1:
if (!NIL_P(str) && !RB_TYPE_P(str, T_STRING)) {
VALUE tmp = rb_check_string_type(str);
if (!NIL_P(rs) && !RB_TYPE_P(rs, T_STRING)) {
VALUE tmp = rb_check_string_type(rs);
if (NIL_P(tmp)) {
limit = NUM2LONG(str);
str = rb_rs;
limit = NUM2LONG(rs);
rs = rb_rs;
}
else {
str = tmp;
rs = tmp;
}
}
break;

case 2:
if (!NIL_P(str)) StringValue(str);
if (!NIL_P(rs)) StringValue(rs);
if (!NIL_P(lim)) limit = NUM2LONG(lim);
break;
}
arg->rs = str;
if (!NIL_P(rs)) {
rb_encoding *enc_rs, *enc_io;
enc_rs = rb_enc_get(rs);
enc_io = get_enc(ptr);
if (enc_rs != enc_io &&
(rb_enc_str_coderange(rs) != ENC_CODERANGE_7BIT ||
(RSTRING_LEN(rs) > 0 && !rb_enc_asciicompat(enc_io)))) {
if (rs == rb_rs) {
rs = rb_enc_str_new(0, 0, enc_io);
rb_str_buf_cat_ascii(rs, "\n");
rs = rs;
}
else {
rb_raise(rb_eArgError, "encoding mismatch: %s IO with %s RS",
rb_enc_name(enc_io),
rb_enc_name(enc_rs));
}
}
}
arg->rs = rs;
arg->limit = limit;
arg->chomp = 0;
if (!NIL_P(opts)) {
Expand Down Expand Up @@ -1302,15 +1321,15 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr)
static VALUE
strio_gets(int argc, VALUE *argv, VALUE self)
{
struct StringIO *ptr = readable(self);
struct getline_arg arg;
VALUE str;

if (prepare_getline_args(&arg, argc, argv)->limit == 0) {
struct StringIO *ptr = readable(self);
if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
return rb_enc_str_new(0, 0, get_enc(ptr));
}

str = strio_getline(&arg, readable(self));
str = strio_getline(&arg, ptr);
rb_lastline_set(str);
return str;
}
Expand Down Expand Up @@ -1347,16 +1366,16 @@ static VALUE
strio_each(int argc, VALUE *argv, VALUE self)
{
VALUE line;
struct StringIO *ptr = readable(self);
struct getline_arg arg;

StringIO(self);
RETURN_ENUMERATOR(self, argc, argv);

if (prepare_getline_args(&arg, argc, argv)->limit == 0) {
if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
rb_raise(rb_eArgError, "invalid limit: 0 for each_line");
}

while (!NIL_P(line = strio_getline(&arg, readable(self)))) {
while (!NIL_P(line = strio_getline(&arg, ptr))) {
rb_yield(line);
}
return self;
Expand All @@ -1374,15 +1393,15 @@ static VALUE
strio_readlines(int argc, VALUE *argv, VALUE self)
{
VALUE ary, line;
struct StringIO *ptr = readable(self);
struct getline_arg arg;

StringIO(self);
ary = rb_ary_new();
if (prepare_getline_args(&arg, argc, argv)->limit == 0) {
if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) {
rb_raise(rb_eArgError, "invalid limit: 0 for readlines");
}

while (!NIL_P(line = strio_getline(&arg, readable(self)))) {
ary = rb_ary_new();
while (!NIL_P(line = strio_getline(&arg, ptr))) {
rb_ary_push(ary, line);
}
return ary;
Expand Down
8 changes: 8 additions & 0 deletions test/stringio/test_stringio.rb
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ def test_gets
assert_string("", Encoding::UTF_8, StringIO.new("foo").gets(0))
end

def test_gets_utf_16
stringio = StringIO.new("line1\nline2\nline3\n".encode("utf-16le"))
assert_equal("line1\n".encode("utf-16le"), stringio.gets)
assert_equal("line2\n".encode("utf-16le"), stringio.gets)
assert_equal("line3\n".encode("utf-16le"), stringio.gets)
assert_nil(stringio.gets)
end

def test_gets_chomp
assert_equal(nil, StringIO.new("").gets(chomp: true))
assert_equal("", StringIO.new("\n").gets(chomp: true))
Expand Down

0 comments on commit 4b170c1

Please sign in to comment.