From a1fda16b238f24cf55814ecc18f716cbfff8dd91 Mon Sep 17 00:00:00 2001 From: Dylan Thacker-Smith Date: Fri, 27 Sep 2019 12:24:25 -0400 Subject: [PATCH] Optimize Array#flatten and flatten! for already flattened arrays (#2495) * Optimize Array#flatten and flatten! for already flattened arrays * Add benchmark for Array#flatten and Array#flatten! [Bug #16119] --- array.c | 43 ++++++++++++++++++++++++++++--------- benchmark/array_flatten.yml | 19 ++++++++++++++++ 2 files changed, 52 insertions(+), 10 deletions(-) create mode 100644 benchmark/array_flatten.yml diff --git a/array.c b/array.c index 825d9f7126..37456147b1 100644 --- a/array.c +++ b/array.c @@ -5122,21 +5122,43 @@ rb_ary_count(int argc, VALUE *argv, VALUE ary) } static VALUE -flatten(VALUE ary, int level, int *modified) +flatten(VALUE ary, int level) { - long i = 0; + long i; VALUE stack, result, tmp, elt, vmemo; st_table *memo; st_data_t id; - stack = ary_new(0, ARY_DEFAULT_SIZE); + for (i = 0; i < RARRAY_LEN(ary); i++) { + elt = RARRAY_AREF(ary, i); + tmp = rb_check_array_type(elt); + if (!NIL_P(tmp)) { + break; + } + } + if (i == RARRAY_LEN(ary)) { + return ary; + } else if (tmp == ary) { + rb_raise(rb_eArgError, "tried to flatten recursive array"); + } + result = ary_new(0, RARRAY_LEN(ary)); + ary_memcpy(result, 0, i, RARRAY_CONST_PTR_TRANSIENT(ary)); + ARY_SET_LEN(result, i); + + stack = ary_new(0, ARY_DEFAULT_SIZE); + rb_ary_push(stack, ary); + rb_ary_push(stack, LONG2NUM(i + 1)); + vmemo = rb_hash_new(); RBASIC_CLEAR_CLASS(vmemo); memo = st_init_numtable(); rb_hash_st_table_set(vmemo, memo); st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue); - *modified = 0; + st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue); + + ary = tmp; + i = 0; while (1) { while (i < RARRAY_LEN(ary)) { @@ -5155,7 +5177,6 @@ flatten(VALUE ary, int level, int *modified) rb_ary_push(result, elt); } else { - *modified = 1; id = (st_data_t)tmp; if (st_lookup(memo, id, 0)) { st_clear(memo); @@ -5215,9 +5236,8 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary) if (!NIL_P(lv)) level = NUM2INT(lv); if (level == 0) return Qnil; - result = flatten(ary, level, &mod); - if (mod == 0) { - ary_discard(result); + result = flatten(ary, level); + if (result == ary) { return Qnil; } if (!(mod = ARY_EMBED_P(result))) rb_obj_freeze(result); @@ -5252,7 +5272,7 @@ rb_ary_flatten_bang(int argc, VALUE *argv, VALUE ary) static VALUE rb_ary_flatten(int argc, VALUE *argv, VALUE ary) { - int mod = 0, level = -1; + int level = -1; VALUE result; if (rb_check_arity(argc, 0, 1) && !NIL_P(argv[0])) { @@ -5260,7 +5280,10 @@ rb_ary_flatten(int argc, VALUE *argv, VALUE ary) if (level == 0) return ary_make_shared_copy(ary); } - result = flatten(ary, level, &mod); + result = flatten(ary, level); + if (result == ary) { + result = ary_make_shared_copy(ary); + } OBJ_INFECT(result, ary); return result; diff --git a/benchmark/array_flatten.yml b/benchmark/array_flatten.yml new file mode 100644 index 0000000000..88ef544ba0 --- /dev/null +++ b/benchmark/array_flatten.yml @@ -0,0 +1,19 @@ +prelude: | + small_flat_ary = 5.times.to_a + large_flat_ary = 100.times.to_a + small_pairs_ary = [[1, 2]] * 5 + large_pairs_ary = [[1, 2]] * 100 + mostly_flat_ary = 100.times.to_a.push([101, 102]) + +benchmark: + small_flat_ary.flatten: small_flat_ary.flatten + small_flat_ary.flatten!: small_flat_ary.flatten! + large_flat_ary.flatten: large_flat_ary.flatten + large_flat_ary.flatten!: large_flat_ary.flatten! + small_pairs_ary.flatten: small_pairs_ary.flatten + small_pairs_ary.flatten!: small_pairs_ary.dup.flatten! + large_pairs_ary.flatten: large_pairs_ary.flatten + large_pairs_ary.flatten!: large_pairs_ary.dup.flatten! + mostly_flat_ary.flatten: mostly_flat_ary.flatten + mostly_flat_ary.flatten!: mostly_flat_ary.dup.flatten! +loop_count: 10000