diff --git a/enumerator.c b/enumerator.c index b65712fe27..1522a3f699 100644 --- a/enumerator.c +++ b/enumerator.c @@ -105,7 +105,7 @@ VALUE rb_cEnumerator; VALUE rb_cLazy; static ID id_rewind, id_each, id_new, id_initialize, id_yield, id_call, id_size; -static ID id_eqq, id_next, id_result, id_lazy, id_receiver, id_arguments, id_method, id_force; +static ID id_eqq, id_next, id_result, id_lazy, id_receiver, id_arguments, id_memo, id_method, id_force; static VALUE sym_each, sym_cycle; VALUE rb_eStopIteration; @@ -1641,14 +1641,18 @@ lazy_zip(int argc, VALUE *argv, VALUE obj) static VALUE lazy_take_func(VALUE val, VALUE args, int argc, VALUE *argv) { - NODE *memo = RNODE(args); + long remain; + VALUE memo = rb_ivar_get(argv[0], id_memo); + if (NIL_P(memo)) { + memo = args; + } rb_funcall2(argv[0], id_yield, argc - 1, argv + 1); - if (--memo->u3.cnt == 0) { - memo->u3.cnt = memo->u2.argc; + if ((remain = NUM2LONG(memo)-1) == 0) { return Qundef; } else { + rb_ivar_set(argv[0], id_memo, LONG2NUM(remain)); return Qnil; } } @@ -1666,7 +1670,6 @@ lazy_take_size(VALUE lazy) static VALUE lazy_take(VALUE obj, VALUE n) { - NODE *memo; long len = NUM2LONG(n); int argc = 1; VALUE argv[3]; @@ -1680,9 +1683,8 @@ lazy_take(VALUE obj, VALUE n) argv[2] = INT2NUM(0); argc = 3; } - memo = NEW_MEMO(0, len, len); return lazy_set_method(rb_block_call(rb_cLazy, id_new, argc, argv, - lazy_take_func, (VALUE) memo), + lazy_take_func, n), rb_ary_new3(1, n), lazy_take_size); } @@ -1955,6 +1957,7 @@ Init_Enumerator(void) id_eqq = rb_intern("==="); id_receiver = rb_intern("receiver"); id_arguments = rb_intern("arguments"); + id_memo = rb_intern("memo"); id_method = rb_intern("method"); id_force = rb_intern("force"); sym_each = ID2SYM(id_each); diff --git a/test/ruby/test_lazy_enumerator.rb b/test/ruby/test_lazy_enumerator.rb index acd4843afb..35e92c9050 100644 --- a/test/ruby/test_lazy_enumerator.rb +++ b/test/ruby/test_lazy_enumerator.rb @@ -243,6 +243,23 @@ class TestLazyEnumerator < Test::Unit::TestCase assert_equal((1..5).to_a, take5.force, bug6428) end + def test_take_nested + bug7696 = '[ruby-core:51470]' + a = Step.new(1..10) + take5 = a.lazy.take(5) + assert_equal([*(1..5)]*5, take5.flat_map{take5}.force, bug7696) + end + + def test_take_rewound + bug7696 = '[ruby-core:51470]' + e=(1..42).lazy.take(2) + assert_equal 1, e.next + assert_equal 2, e.next + e.rewind + assert_equal 1, e.next + assert_equal 2, e.next + end + def test_take_while a = Step.new(1..10) assert_equal(1, a.take_while {|i| i < 5}.first)