1
0
Fork 0
mirror of https://github.com/ruby/ruby.git synced 2022-11-09 12:17:21 -05:00

Implement Enumerator#+ and Enumerable#chain [Feature #15144]

They return an Enumerator::Chain object which is a subclass of
Enumerator, which represents a chain of enumerables that works as a
single enumerator.

```ruby
e = (1..3).chain([4, 5])
e.to_a #=> [1, 2, 3, 4, 5]

e = (1..3).each + [4, 5]
e.to_a #=> [1, 2, 3, 4, 5]
```

git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@65949 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
This commit is contained in:
knu 2018-11-24 08:38:35 +00:00
parent c0e20037f3
commit 045b0e54d8
2 changed files with 429 additions and 1 deletions

View file

@ -12,6 +12,7 @@
************************************************/
#include "ruby/ruby.h"
#include "internal.h"
#include "id.h"
@ -161,6 +162,13 @@ struct proc_entry {
static VALUE generator_allocate(VALUE klass);
static VALUE generator_init(VALUE obj, VALUE proc);
static VALUE rb_cEnumChain;
struct enum_chain {
VALUE enums;
long pos;
};
static VALUE rb_cArithSeq;
/*
@ -2411,6 +2419,300 @@ stop_result(VALUE self)
return rb_attr_get(self, id_result);
}
/*
* Document-class: Enumerator::Chain
*
* Enumerator::Chain is a subclass of Enumerator, which represents a
* chain of enumerables that works as a single enumerator.
*
* This type of objects can be created by Enumerable#chain and
* Enumerator#+.
*/
static void
enum_chain_mark(void *p)
{
struct enum_chain *ptr = p;
rb_gc_mark(ptr->enums);
}
#define enum_chain_free RUBY_TYPED_DEFAULT_FREE
static size_t
enum_chain_memsize(const void *p)
{
return sizeof(struct enum_chain);
}
static const rb_data_type_t enum_chain_data_type = {
"chain",
{
enum_chain_mark,
enum_chain_free,
enum_chain_memsize,
},
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
};
static struct enum_chain *
enum_chain_ptr(VALUE obj)
{
struct enum_chain *ptr;
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
if (!ptr || ptr->enums == Qundef) {
rb_raise(rb_eArgError, "uninitialized chain");
}
return ptr;
}
/* :nodoc: */
static VALUE
enum_chain_allocate(VALUE klass)
{
struct enum_chain *ptr;
VALUE obj;
obj = TypedData_Make_Struct(klass, struct enum_chain, &enum_chain_data_type, ptr);
ptr->enums = Qundef;
ptr->pos = -1;
return obj;
}
/*
* call-seq:
* Enumerator::Chain.new(*enums) -> enum
*
* Generates a new enumerator object that iterates over the elements
* of given enumerable objects in sequence.
*
* e = Enumerator::Chain.new(1..3, [4, 5])
* e.to_a #=> [1, 2, 3, 4, 5]
* e.size #=> 5
*/
static VALUE
enum_chain_initialize(VALUE obj, VALUE enums)
{
struct enum_chain *ptr;
rb_check_frozen(obj);
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
if (!ptr) rb_raise(rb_eArgError, "unallocated chain");
ptr->enums = rb_obj_freeze(enums);
ptr->pos = -1;
return obj;
}
/* :nodoc: */
static VALUE
enum_chain_init_copy(VALUE obj, VALUE orig)
{
struct enum_chain *ptr0, *ptr1;
if (!OBJ_INIT_COPY(obj, orig)) return obj;
ptr0 = enum_chain_ptr(orig);
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr1);
if (!ptr1) rb_raise(rb_eArgError, "unallocated chain");
ptr1->enums = ptr0->enums;
ptr1->pos = ptr0->pos;
return obj;
}
static VALUE
enum_chain_total_size(VALUE enums)
{
VALUE total = INT2FIX(0);
RARRAY_PTR_USE(enums, ptr, {
long i;
for (i = 0; i < RARRAY_LEN(enums); i++) {
VALUE size = enum_size(ptr[i]);
if (NIL_P(size) || (RB_TYPE_P(size, T_FLOAT) && isinf(NUM2DBL(size)))) {
return size;
}
if (!RB_INTEGER_TYPE_P(size)) {
return Qnil;
}
total = rb_funcall(total, '+', 1, size);
}
});
return total;
}
/*
* call-seq:
* obj.size -> integer
*
* Returns the total size of the enumerator chain calculated by
* summing up the size of each enumerable in the chain. If any of the
* enumerables reports its size as nil or Float::INFINITY, that value
* is returned as the total size.
*/
static VALUE
enum_chain_size(VALUE obj)
{
return enum_chain_total_size(enum_chain_ptr(obj)->enums);
}
static VALUE
enum_chain_enum_size(VALUE obj, VALUE args, VALUE eobj)
{
return enum_chain_size(obj);
}
static VALUE
enum_chain_yield_block(VALUE arg, VALUE block, int argc, VALUE *argv)
{
return rb_funcallv(block, rb_intern("call"), argc, argv);
}
static VALUE
enum_chain_enum_no_size(VALUE obj, VALUE args, VALUE eobj)
{
return Qnil;
}
/*
* call-seq:
* obj.each(*args) { |...| ... } -> obj
* obj.each(*args) -> enumerator
*
* Iterates over the elements of the first enumerable by calling the
* "each" method on it with the given arguments, then proceeds to the
* following enumerables in sequence until all of the enumerables are
* exhausted.
*
* If no block is given, returns an enumerator.
*/
static VALUE
enum_chain_each(int argc, VALUE *argv, VALUE obj)
{
VALUE enums, block;
struct enum_chain *objptr;
RETURN_SIZED_ENUMERATOR(obj, argc, argv, argc > 0 ? enum_chain_enum_no_size : enum_chain_enum_size);
objptr = enum_chain_ptr(obj);
enums = objptr->enums;
block = rb_block_proc();
RARRAY_PTR_USE(enums, ptr, {
long i;
for (i = 0; i < RARRAY_LEN(enums); i++) {
objptr->pos = i;
rb_block_call(ptr[i], id_each, argc, argv, enum_chain_yield_block, block);
}
});
return obj;
}
/*
* call-seq:
* obj.rewind -> obj
*
* Rewinds the enumerator chain by calling the "rewind" method on each
* enumerable in reverse order. Each call is performed only if the
* enumerable responds to the method.
*/
static VALUE
enum_chain_rewind(VALUE obj)
{
struct enum_chain *objptr = enum_chain_ptr(obj);
VALUE enums = objptr->enums;
RARRAY_PTR_USE(enums, ptr, {
long i;
for (i = objptr->pos; 0 <= i && i < RARRAY_LEN(enums); objptr->pos = --i) {
rb_check_funcall(ptr[i], id_rewind, 0, 0);
}
});
return obj;
}
static VALUE
inspect_enum_chain(VALUE obj, VALUE dummy, int recur)
{
VALUE klass = rb_obj_class(obj);
struct enum_chain *ptr;
TypedData_Get_Struct(obj, struct enum_chain, &enum_chain_data_type, ptr);
if (!ptr || ptr->enums == Qundef) {
return rb_sprintf("#<%"PRIsVALUE": uninitialized>", rb_class_path(klass));
}
if (recur) {
return rb_sprintf("#<%"PRIsVALUE": ...>", rb_class_path(klass));
}
return rb_sprintf("#<%"PRIsVALUE": %+"PRIsVALUE">", rb_class_path(klass), ptr->enums);
}
/*
* call-seq:
* obj.inspect -> string
*
* Returns a printable version of the enumerator chain.
*/
static VALUE
enum_chain_inspect(VALUE obj)
{
return rb_exec_recursive(inspect_enum_chain, obj, 0);
}
/*
* call-seq:
* e.chain(*enums) -> enumerator
*
* Returns an enumerator object generated from this enumerator and
* given enumerables.
*
* e = (1..3).chain([4, 5])
* e.to_a #=> [1, 2, 3, 4, 5]
*/
static VALUE
enum_chain(int argc, VALUE *argv, VALUE obj)
{
VALUE enums = rb_ary_new_from_values(1, &obj);
rb_ary_cat(enums, argv, argc);
return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
}
/*
* call-seq:
* e + enum -> enumerator
*
* Returns an enumerator object generated from this enumerator and a
* given enumerable.
*
* e = (1..3).each + [4, 5]
* e.to_a #=> [1, 2, 3, 4, 5]
*/
static VALUE
enumerator_plus(VALUE obj, VALUE eobj)
{
VALUE enums = rb_ary_new_from_args(2, obj, eobj);
return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums);
}
/*
* Document-class: Enumerator::ArithmeticSequence
*
@ -2907,6 +3209,8 @@ InitVM_Enumerator(void)
rb_define_method(rb_cEnumerator, "rewind", enumerator_rewind, 0);
rb_define_method(rb_cEnumerator, "inspect", enumerator_inspect, 0);
rb_define_method(rb_cEnumerator, "size", enumerator_size, 0);
rb_define_method(rb_cEnumerator, "+", enumerator_plus, 1);
rb_define_method(rb_mEnumerable, "chain", enum_chain, -1);
/* Lazy */
rb_cLazy = rb_define_class_under(rb_cEnumerator, "Lazy", rb_cEnumerator);
@ -2960,6 +3264,16 @@ InitVM_Enumerator(void)
rb_define_method(rb_cYielder, "yield", yielder_yield, -2);
rb_define_method(rb_cYielder, "<<", yielder_yield_push, 1);
/* Chain */
rb_cEnumChain = rb_define_class_under(rb_cEnumerator, "Chain", rb_cEnumerator);
rb_define_alloc_func(rb_cEnumChain, enum_chain_allocate);
rb_define_method(rb_cEnumChain, "initialize", enum_chain_initialize, -2);
rb_define_method(rb_cEnumChain, "initialize_copy", enum_chain_init_copy, 1);
rb_define_method(rb_cEnumChain, "each", enum_chain_each, -1);
rb_define_method(rb_cEnumChain, "size", enum_chain_size, 0);
rb_define_method(rb_cEnumChain, "rewind", enum_chain_rewind, 0);
rb_define_method(rb_cEnumChain, "inspect", enum_chain_inspect, 0);
/* ArithmeticSequence */
rb_cArithSeq = rb_define_class_under(rb_cEnumerator, "ArithmeticSequence", rb_cEnumerator);
rb_undef_alloc_func(rb_cArithSeq);

View file

@ -670,5 +670,119 @@ class TestEnumerator < Test::Unit::TestCase
assert_equal([0, 1], u.force)
assert_equal([0, 1], u.force)
end
end
def test_enum_chain_and_plus
r = 1..5
e1 = r.chain()
assert_kind_of(Enumerator::Chain, e1)
assert_equal(5, e1.size)
ary = []
e1.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5], ary)
e2 = r.chain([6, 7, 8])
assert_kind_of(Enumerator::Chain, e2)
assert_equal(8, e2.size)
ary = []
e2.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
e3 = r.chain([6, 7], 8.step)
assert_kind_of(Enumerator::Chain, e3)
assert_equal(Float::INFINITY, e3.size)
ary = []
e3.take(10).each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
# `a + b + c` should not return `Enumerator::Chain.new(a, b, c)`
# because it is expected that `(a + b).each` be called.
e4 = e2.dup
class << e4
attr_reader :each_is_called
def each
super
@each_is_called = true
end
end
e5 = e4 + 9.step
assert_kind_of(Enumerator::Chain, e5)
assert_equal(Float::INFINITY, e5.size)
ary = []
e5.take(10).each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
assert_equal(true, e4.each_is_called)
end
def test_chained_enums
a = (1..5).each
e0 = Enumerator::Chain.new()
assert_kind_of(Enumerator::Chain, e0)
assert_equal(0, e0.size)
ary = []
e0.each { |x| ary << x }
assert_equal([], ary)
e1 = Enumerator::Chain.new(a)
assert_kind_of(Enumerator::Chain, e1)
assert_equal(5, e1.size)
ary = []
e1.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5], ary)
e2 = Enumerator::Chain.new(a, [6, 7, 8])
assert_kind_of(Enumerator::Chain, e2)
assert_equal(8, e2.size)
ary = []
e2.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
e3 = Enumerator::Chain.new(a, [6, 7], 8.step)
assert_kind_of(Enumerator::Chain, e3)
assert_equal(Float::INFINITY, e3.size)
ary = []
e3.take(10).each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], ary)
e4 = Enumerator::Chain.new(a, Enumerator.new { |y| y << 6 << 7 << 8 })
assert_kind_of(Enumerator::Chain, e4)
assert_equal(nil, e4.size)
ary = []
e4.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 6, 7, 8], ary)
e5 = Enumerator::Chain.new(e1, e2)
assert_kind_of(Enumerator::Chain, e5)
assert_equal(13, e5.size)
ary = []
e5.each { |x| ary << x }
assert_equal([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8], ary)
rewound = []
e1.define_singleton_method(:rewind) { rewound << object_id }
e2.define_singleton_method(:rewind) { rewound << object_id }
e5.rewind
assert_equal(rewound, [e2.object_id, e1.object_id])
rewound = []
a = [1]
e6 = Enumerator::Chain.new(a)
a.define_singleton_method(:rewind) { rewound << object_id }
e6.rewind
assert_equal(rewound, [])
assert_equal(
'#<Enumerator::Chain: [' +
'#<Enumerator::Chain: [' +
'#<Enumerator: 1..5:each>' +
']>, ' +
'#<Enumerator::Chain: [' +
'#<Enumerator: 1..5:each>, ' +
'[6, 7, 8]' +
']>' +
']>',
e5.inspect
)
end
end