diff --git a/class.c b/class.c index 12a67d16bc..68cfbfb257 100644 --- a/class.c +++ b/class.c @@ -351,7 +351,7 @@ copy_tables(VALUE clone, VALUE orig) } } -static void ensure_origin(VALUE klass); +static bool ensure_origin(VALUE klass); /* :nodoc: */ VALUE @@ -1014,27 +1014,31 @@ clear_module_cache_i(ID id, VALUE val, void *data) return ID_TABLE_CONTINUE; } +static bool +module_in_super_chain(const VALUE klass, VALUE module) +{ + struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(RCLASS_ORIGIN(klass)); + if (klass_m_tbl) { + while (module) { + if (klass_m_tbl == RCLASS_M_TBL(module)) + return true; + module = RCLASS_SUPER(module); + } + } + return false; +} + static int -include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super) +do_include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super, bool check_cyclic) { VALUE p, iclass, origin_stack = 0; int method_changed = 0, constant_changed = 0, add_subclass; long origin_len; VALUE klass_origin = RCLASS_ORIGIN(klass); - struct rb_id_table *const klass_m_tbl = RCLASS_M_TBL(klass_origin); VALUE original_klass = klass; - if (klass_m_tbl) { - VALUE original_module = module; - - while (module) { - if (klass_m_tbl == RCLASS_M_TBL(module)) - return -1; - module = RCLASS_SUPER(module); - } - - module = original_module; - } + if (check_cyclic && module_in_super_chain(klass, module)) + return -1; while (module) { int c_seen = FALSE; @@ -1129,6 +1133,12 @@ include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super) return method_changed; } +static int +include_modules_at(const VALUE klass, VALUE c, VALUE module, int search_super) +{ + return do_include_modules_at(klass, c, module, search_super, true); +} + static enum rb_id_table_iterator_result move_refined_method(ID key, VALUE value, void *data) { @@ -1169,7 +1179,7 @@ cache_clear_refined_method(ID key, VALUE value, void *data) return ID_TABLE_CONTINUE; } -static void +static bool ensure_origin(VALUE klass) { VALUE origin = RCLASS_ORIGIN(klass); @@ -1182,20 +1192,24 @@ ensure_origin(VALUE klass) RCLASS_M_TBL_INIT(klass); rb_id_table_foreach(RCLASS_M_TBL(origin), cache_clear_refined_method, (void *)klass); rb_id_table_foreach(RCLASS_M_TBL(origin), move_refined_method, (void *)klass); + return true; } + return false; } void rb_prepend_module(VALUE klass, VALUE module) { - int changed = 0; - bool klass_had_no_origin = RCLASS_ORIGIN(klass) == klass; + int changed; + bool klass_had_no_origin; ensure_includable(klass, module); - ensure_origin(klass); - changed = include_modules_at(klass, klass, module, FALSE); - if (changed < 0) - rb_raise(rb_eArgError, "cyclic prepend detected"); + if (module_in_super_chain(klass, module)) + rb_raise(rb_eArgError, "cyclic prepend detected"); + + klass_had_no_origin = ensure_origin(klass); + changed = do_include_modules_at(klass, klass, module, FALSE, false); + RUBY_ASSERT(changed >= 0); // already checked for cyclic prepend above if (changed) { rb_vm_check_redefinition_by_prepend(klass); } diff --git a/test/ruby/test_module.rb b/test/ruby/test_module.rb index 0a5597fd6c..e5152b1012 100644 --- a/test/ruby/test_module.rb +++ b/test/ruby/test_module.rb @@ -485,6 +485,19 @@ class TestModule < Test::Unit::TestCase assert_equal([m], m.ancestors) end + def test_bug17590 + m = Module.new + c = Class.new + c.prepend(m) + c.include(m) + m.prepend(m) rescue nil + m2 = Module.new + m2.prepend(m) + c.include(m2) + + assert_equal([m, c, m2] + Object.ancestors, c.ancestors) + end + def test_prepend_works_with_duped_classes m = Module.new a = Class.new do