diff --git a/ChangeLog b/ChangeLog index da7cf89a52..03aa2e634e 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,14 @@ +Thu Sep 27 18:36:51 2012 Shugo Maeda + + * eval.c (rb_overlay_module, rb_mod_refine): accept a module as the + argument of Module#refine. + + * vm_method.c (search_method): if klass is an iclass, lookup the + original module of the iclass in omod in order to allow + refinements of modules. + + * test/ruby/test_refinement.rb: add tests for the above changes. + Thu Sep 27 18:12:20 2012 Aaron Patterson * ext/syslog/lib/syslog/logger.rb: add a formatter to the diff --git a/eval.c b/eval.c index 9630ca2621..4d27a8ec6b 100644 --- a/eval.c +++ b/eval.c @@ -1030,12 +1030,22 @@ rb_mod_prepend(int argc, VALUE *argv, VALUE module) return module; } +static +void check_class_or_module(VALUE obj) +{ + if (!RB_TYPE_P(obj, T_CLASS) && !RB_TYPE_P(obj, T_MODULE)) { + VALUE str = rb_inspect(obj); + rb_raise(rb_eTypeError, "%s is not a class/module", + StringValuePtr(str)); + } +} + void rb_overlay_module(NODE *cref, VALUE klass, VALUE module) { VALUE iclass, c, superclass = klass; - Check_Type(klass, T_CLASS); + check_class_or_module(klass); Check_Type(module, T_MODULE); if (NIL_P(cref->nd_omod)) { cref->nd_omod = rb_hash_new(); @@ -1184,7 +1194,7 @@ rb_mod_refine(VALUE module, VALUE klass) ID id_overlaid_modules, id_refined_class; VALUE overlaid_modules; - Check_Type(klass, T_CLASS); + check_class_or_module(klass); CONST_ID(id_overlaid_modules, "__overlaid_modules__"); overlaid_modules = rb_attr_get(module, id_overlaid_modules); if (NIL_P(overlaid_modules)) { diff --git a/test/ruby/test_refinement.rb b/test/ruby/test_refinement.rb index cc460894c4..ac6df929f0 100644 --- a/test/ruby/test_refinement.rb +++ b/test/ruby/test_refinement.rb @@ -324,5 +324,61 @@ class TestRefinement < Test::Unit::TestCase obj = c.new assert_equal([:c, :m1, :m2], m2.module_eval { obj.foo }) end -end + def test_refine_module_without_overriding + m1 = Module.new + c = Class.new { + include m1 + } + m2 = Module.new { + refine m1 do + def foo + :m2 + end + end + } + obj = c.new + assert_equal(:m2, m2.module_eval { obj.foo }) + end + + def test_refine_module_with_overriding + m1 = Module.new { + def foo + [:m1] + end + } + c = Class.new { + include m1 + } + m2 = Module.new { + refine m1 do + def foo + super << :m2 + end + end + } + obj = c.new + assert_equal([:m1, :m2], m2.module_eval { obj.foo }) + end + + def test_refine_neither_class_nor_module + assert_raise(TypeError) do + Module.new { + refine Object.new do + end + } + end + assert_raise(TypeError) do + Module.new { + refine 123 do + end + } + end + assert_raise(TypeError) do + Module.new { + refine "foo" do + end + } + end + end +end diff --git a/vm_method.c b/vm_method.c index 43b9ddb107..bf605e6f67 100644 --- a/vm_method.c +++ b/vm_method.c @@ -385,10 +385,20 @@ search_method(VALUE klass, ID id, VALUE omod, VALUE *defined_class_ptr) for (body = 0; klass; klass = RCLASS_SUPER(klass)) { st_table *m_tbl; - if (!NIL_P(omod) && klass != skipped_class && - !NIL_P(iclass = rb_hash_lookup(omod, klass))) { - skipped_class = klass; - klass = iclass; + if (!NIL_P(omod) && klass != skipped_class) { + VALUE c; + + if (BUILTIN_TYPE(klass) == T_ICLASS) { + c = RBASIC(klass)->klass; + } + else { + c = klass; + } + iclass = rb_hash_lookup(omod, c); + if (!NIL_P(iclass)) { + skipped_class = klass; + klass = iclass; + } } m_tbl = RCLASS_M_TBL(klass); if (!m_tbl) {