diff --git a/ext/date/date_core.c b/ext/date/date_core.c index 3b18bbd5b1..c250633426 100644 --- a/ext/date/date_core.c +++ b/ext/date/date_core.c @@ -51,18 +51,18 @@ static double positive_inf, negative_inf; #define f_add3(x,y,z) f_add(f_add(x, y), z) #define f_sub3(x,y,z) f_sub(f_sub(x, y), z) -inline static VALUE +inline static int f_cmp(VALUE x, VALUE y) { if (FIXNUM_P(x) && FIXNUM_P(y)) { long c = FIX2LONG(x) - FIX2LONG(y); if (c > 0) - c = 1; + return 1; else if (c < 0) - c = -1; - return INT2FIX(c); + return -1; + return 0; } - return rb_funcall(x, id_cmp, 1, y); + return rb_cmpint(rb_funcallv(x, id_cmp, 1, &y), x, y); } inline static VALUE @@ -6154,6 +6154,7 @@ static VALUE d_lite_step(int argc, VALUE *argv, VALUE self) { VALUE limit, step, date; + int c; rb_scan_args(argc, argv, "11", &limit, &step); @@ -6168,25 +6169,22 @@ d_lite_step(int argc, VALUE *argv, VALUE self) RETURN_ENUMERATOR(self, argc, argv); date = self; - switch (FIX2INT(f_cmp(step, INT2FIX(0)))) { - case -1: + c = f_cmp(step, INT2FIX(0)); + if (c < 0) { while (FIX2INT(d_lite_cmp(date, limit)) >= 0) { rb_yield(date); date = d_lite_plus(date, step); } - break; - case 0: + } + else if (c == 0) { while (1) rb_yield(date); - break; - case 1: + } + else /* if (c > 0) */ { while (FIX2INT(d_lite_cmp(date, limit)) <= 0) { rb_yield(date); date = d_lite_plus(date, step); } - break; - default: - abort(); } return self; } @@ -6241,9 +6239,9 @@ cmp_gen(VALUE self, VALUE other) get_d1(self); if (k_numeric_p(other)) - return f_cmp(m_ajd(dat), other); + return INT2FIX(f_cmp(m_ajd(dat), other)); else if (k_date_p(other)) - return f_cmp(m_ajd(dat), f_ajd(other)); + return INT2FIX(f_cmp(m_ajd(dat), f_ajd(other))); return rb_num_coerce_cmp(self, other, rb_intern("<=>")); } diff --git a/test/date/test_date_arith.rb b/test/date/test_date_arith.rb index 96622ba065..d0d27d72f7 100644 --- a/test/date/test_date_arith.rb +++ b/test/date/test_date_arith.rb @@ -262,4 +262,17 @@ class TestDateArith < Test::Unit::TestCase assert_equal(8, e.to_a.size) end + def test_step__compare + o = Object.new + def o.<=>(*);end + assert_raise(ArgumentError) { + Date.new(2000, 1, 1).step(3, o).to_a + } + + o = Object.new + def o.<=>(*);2;end + a = [] + Date.new(2000, 1, 1).step(3, o) {|d| a << d} + assert_empty(a) + end end