diff --git a/ChangeLog b/ChangeLog index 7c66c50472..e47ce5da96 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,9 @@ +Mon Oct 8 23:55:41 2012 Shugo Maeda + + * eval.c (rb_mod_refinements): new method Module#refinements. + + * test/ruby/test_refinement.rb: add new tests for the above changes. + Mon Oct 8 23:02:19 2012 Shugo Maeda * eval.c, gc.c, iseq.c, node.h, vm_insnhelper.c, vm_insnhelper.h, diff --git a/eval.c b/eval.c index c7e722323f..6dfa3e7399 100644 --- a/eval.c +++ b/eval.c @@ -1226,6 +1226,37 @@ rb_mod_refine(VALUE module, VALUE klass) return mod; } +static int +refinements_i(VALUE key, VALUE value, VALUE arg) +{ + rb_hash_aset(arg, key, value); + return ST_CONTINUE; +} + +/* + * call-seq: + * refinements -> hash + * + * Returns refinements in the receiver as a hash table, whose key is a + * refined class and whose value is a refinement module. + */ + +static VALUE +rb_mod_refinements(VALUE module) +{ + ID id_refinements; + VALUE refinements, result; + + CONST_ID(id_refinements, "__refinements__"); + refinements = rb_attr_get(module, id_refinements); + if (NIL_P(refinements)) { + return rb_hash_new(); + } + result = rb_hash_new(); + rb_hash_foreach(refinements, refinements_i, result); + return result; +} + void rb_obj_call_init(VALUE obj, int argc, VALUE *argv) { @@ -1524,6 +1555,7 @@ Init_eval(void) rb_define_private_method(rb_cModule, "prepend", rb_mod_prepend, -1); rb_define_private_method(rb_cModule, "using", rb_mod_using, 1); rb_define_private_method(rb_cModule, "refine", rb_mod_refine, 1); + rb_define_method(rb_cModule, "refinements", rb_mod_refinements, 0); rb_undef_method(rb_cClass, "module_function"); diff --git a/test/ruby/test_refinement.rb b/test/ruby/test_refinement.rb index 476c6442c7..6a1a1ccf80 100644 --- a/test/ruby/test_refinement.rb +++ b/test/ruby/test_refinement.rb @@ -455,4 +455,69 @@ class TestRefinement < Test::Unit::TestCase end end end + + def test_refinements_empty + m = Module.new + assert(m.refinements.empty?) + end + + def test_refinements_one + c = Class.new + c_ext = nil + m = Module.new { + refine c do + c_ext = self + end + } + assert_equal({c => c_ext}, m.refinements) + end + + def test_refinements_two + c1 = Class.new + c1_ext = nil + c2 = Class.new + c2_ext = nil + m = Module.new { + refine c1 do + c1_ext = self + end + + refine c2 do + c2_ext = self + end + } + assert_equal({c1 => c1_ext, c2 => c2_ext}, m.refinements) + end + + def test_refinements_duplicate_refine + c = Class.new + c_ext = nil + m = Module.new { + refine c do + c_ext = self + end + refine c do + end + } + assert_equal({c => c_ext}, m.refinements) + end + + def test_refinements_no_recursion + c1 = Class.new + c1_ext = nil + m1 = Module.new { + refine c1 do + c1_ext = self + end + } + c2 = Class.new + c2_ext = nil + m2 = Module.new { + using m1 + refine c2 do + c2_ext = self + end + } + assert_equal({c2 => c2_ext}, m2.refinements) + end end