diff --git a/ext/stringio/stringio.c b/ext/stringio/stringio.c index 9273a8effb..ff2a544af6 100644 --- a/ext/stringio/stringio.c +++ b/ext/stringio/stringio.c @@ -989,15 +989,16 @@ bm_search(const char *little, long llen, const char *big, long blen, const long struct getline_arg { VALUE rs; long limit; + unsigned int chomp: 1; }; static struct getline_arg * prepare_getline_args(struct getline_arg *arg, int argc, VALUE *argv) { - VALUE str, lim; + VALUE str, lim, opts; long limit = -1; - rb_scan_args(argc, argv, "02", &str, &lim); + argc = rb_scan_args(argc, argv, "02:", &str, &lim, &opts); switch (argc) { case 0: str = rb_rs; @@ -1023,15 +1024,35 @@ prepare_getline_args(struct getline_arg *arg, int argc, VALUE *argv) } arg->rs = str; arg->limit = limit; + arg->chomp = 0; + if (!NIL_P(opts)) { + static ID keywords[1]; + VALUE vchomp; + if (!keywords[0]) { + keywords[0] = rb_intern_const("chomp"); + } + rb_get_kwargs(opts, keywords, 0, 1, &vchomp); + arg->chomp = (vchomp != Qundef) && RTEST(vchomp); + } return arg; } +static inline int +chomp_newline_width(const char *s, const char *e) +{ + if (e > s && *--e == '\n') { + return 1; + } + return 0; +} + static VALUE strio_getline(struct getline_arg *arg, struct StringIO *ptr) { const char *s, *e, *p; long n, limit = arg->limit; VALUE str = arg->rs; + int w = 0; if (ptr->pos >= (n = RSTRING_LEN(ptr->string))) { return Qnil; @@ -1043,7 +1064,10 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr) e = rb_enc_right_char_head(s, s + limit, e, get_enc(ptr)); } if (NIL_P(str)) { - str = strio_substr(ptr, ptr->pos, e - s); + if (arg->chomp) { + w = chomp_newline_width(s, e); + } + str = strio_substr(ptr, ptr->pos, e - s - w); } else if ((n = RSTRING_LEN(str)) == 0) { p = s; @@ -1056,23 +1080,28 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr) while ((p = memchr(p, '\n', e - p)) && (p != e)) { if (*++p == '\n') { e = p + 1; + w = (arg->chomp ? 1 : 0); break; } } - str = strio_substr(ptr, s - RSTRING_PTR(ptr->string), e - s); + if (!w && arg->chomp) { + w = chomp_newline_width(s, e); + } + str = strio_substr(ptr, s - RSTRING_PTR(ptr->string), e - s - w); } else if (n == 1) { if ((p = memchr(s, RSTRING_PTR(str)[0], e - s)) != 0) { e = p + 1; + w = (arg->chomp ? 1 : 0); } - str = strio_substr(ptr, ptr->pos, e - s); + str = strio_substr(ptr, ptr->pos, e - s - w); } else { if (n < e - s) { if (e - s < 1024) { for (p = s; p + n <= e; ++p) { if (MEMCMP(p, RSTRING_PTR(str), char, n) == 0) { - e = p + n; + e = p + (arg->chomp ? 0 : n); break; } } @@ -1082,11 +1111,11 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr) p = RSTRING_PTR(str); bm_init_skip(skip, p, n); if ((pos = bm_search(p, n, s, e - s, skip)) >= 0) { - e = s + pos + n; + e = s + pos + (arg->chomp ? 0 : n); } } } - str = strio_substr(ptr, ptr->pos, e - s); + str = strio_substr(ptr, ptr->pos, e - s - w); } ptr->pos = e - RSTRING_PTR(ptr->string); ptr->lineno++; diff --git a/test/stringio/test_stringio.rb b/test/stringio/test_stringio.rb index fc3b6baebb..07435008ad 100644 --- a/test/stringio/test_stringio.rb +++ b/test/stringio/test_stringio.rb @@ -81,6 +81,21 @@ class TestStringIO < Test::Unit::TestCase assert_nothing_raised {StringIO.new("").gets(nil, nil)} end + def test_gets_chomp + assert_equal(nil, StringIO.new("").gets(chomp: true)) + assert_equal("", StringIO.new("\n").gets(chomp: true)) + assert_equal("a", StringIO.new("a\n").gets(chomp: true)) + assert_equal("a", StringIO.new("a\nb\n").gets(chomp: true)) + assert_equal("a", StringIO.new("a").gets(chomp: true)) + assert_equal("a", StringIO.new("a\nb").gets(chomp: true)) + assert_equal("abc", StringIO.new("abc\n\ndef\n").gets(chomp: true)) + assert_equal("abc\n\ndef", StringIO.new("abc\n\ndef\n").gets(nil, chomp: true)) + assert_equal("abc\n", StringIO.new("abc\n\ndef\n").gets("", chomp: true)) + stringio = StringIO.new("abc\n\ndef\n") + assert_equal("abc\n", stringio.gets("", chomp: true)) + assert_equal("def", stringio.gets("", chomp: true)) + end + def test_readlines assert_equal([], StringIO.new("").readlines) assert_equal(["\n"], StringIO.new("\n").readlines) @@ -476,8 +491,12 @@ class TestStringIO < Test::Unit::TestCase def test_each f = StringIO.new("foo\nbar\nbaz\n") assert_equal(["foo\n", "bar\n", "baz\n"], f.each.to_a) + f.rewind + assert_equal(["foo", "bar", "baz"], f.each(chomp: true).to_a) f = StringIO.new("foo\nbar\n\nbaz\n") assert_equal(["foo\nbar\n\n", "baz\n"], f.each("").to_a) + f.rewind + assert_equal(["foo\nbar\n", "baz"], f.each("", chomp: true).to_a) end def test_putc