1
0
Fork 0
mirror of https://github.com/ruby/ruby.git synced 2022-11-09 12:17:21 -05:00

Integer.sqrt [Feature #13219]

git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@57705 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
This commit is contained in:
nobu 2017-02-24 08:36:16 +00:00
parent 395ad27e72
commit bdd6b995f9
3 changed files with 147 additions and 3 deletions

View file

@ -419,14 +419,13 @@ static void
bary_small_rshift(BDIGIT *zds, const BDIGIT *xds, size_t n, int shift, BDIGIT higher_bdigit)
{
BDIGIT_DBL num = 0;
BDIGIT x;
assert(0 <= shift && shift < BITSPERDIG);
num = BIGUP(higher_bdigit);
while (n--) {
num = (num | xds[n]) >> shift;
x = xds[n];
BDIGIT x = xds[n];
num = (num | x) >> shift;
zds[n] = BIGLO(num);
num = BIGUP(x);
}
@ -6762,6 +6761,68 @@ rb_big_even_p(VALUE num)
return Qtrue;
}
unsigned long rb_ulong_isqrt(unsigned long);
#if SIZEOF_BDIGIT*2 > SIZEOF_LONG
BDIGIT rb_bdigit_dbl_isqrt(BDIGIT_DBL);
#else
# define rb_bdigit_dbl_isqrt(x) (BDIGIT)rb_ulong_isqrt(x)
#endif
VALUE
rb_big_isqrt(VALUE n)
{
BDIGIT *nds = BDIGITS(n);
size_t len = BIGNUM_LEN(n);
if (len <= 2) {
BDIGIT sq = rb_bdigit_dbl_isqrt(bary2bdigitdbl(nds, len));
#if SIZEOF_BDIGIT > SIZEOF_LONG
return ULL2NUM(sq);
#else
return ULONG2NUM(sq);
#endif
}
else {
int zbits = nlz(nds[len-1]);
int shift_bits = (len&1 ? BITSPERDIG/2 : BITSPERDIG) - (zbits+1)/2 + 1;
size_t tn = (len+1) / 2, xn = tn;
VALUE t, x = bignew_1(0, xn, 1); /* division may release the GVL */
BDIGIT *tds, *xds = BDIGITS(x);
/* x = (n >> (b/2+1)) */
if (shift_bits == BITSPERDIG) {
MEMCPY(xds, nds+tn, BDIGIT, xn);
}
else if (shift_bits > BITSPERDIG) {
bary_small_rshift(xds, nds+len-xn, xn, shift_bits-BITSPERDIG, 0);
}
else {
bary_small_rshift(xds, nds+len-xn-1, xn, shift_bits, nds[len-1]);
}
/* x |= (1 << (b-1)/2) */
xds[xn-1] |= (BDIGIT)1u <<
((len&1 ? 0 : BITSPERDIG/2) + (BITSPERDIG-zbits-1)/2);
/* t = n/x */
tn += BIGDIVREM_EXTRA_WORDS;
t = bignew_1(0, tn, 1);
tds = BDIGITS(t);
tn = BIGNUM_LEN(t);
while (bary_divmod_branch(tds, tn, NULL, 0, nds, len, xds, xn),
bary_cmp(tds, tn, xds, xn) < 0) {
int carry;
BARY_TRUNC(tds, tn);
carry = bary_add(xds, xn, xds, xn, tds, tn);
bary_small_rshift(xds, xds, xn, 1, carry);
tn = BIGNUM_LEN(t);
}
rb_big_realloc(t, 0);
rb_gc_force_recycle(t);
RBASIC_SET_CLASS_RAW(x, rb_cInteger);
return x;
}
}
/*
* Bignum objects hold integers outside the range of
* Fixnum. Bignum objects are created

View file

@ -5128,6 +5128,64 @@ int_truncate(int argc, VALUE* argv, VALUE num)
return rb_int_truncate(num, ndigits);
}
#define DEFINE_INT_SQRT(rettype, prefix, argtype) \
rettype \
prefix##_isqrt(argtype n) \
{ \
if (sizeof(n) * CHAR_BIT > DBL_MANT_DIG && \
n >= ((argtype)1UL << DBL_MANT_DIG)) { \
unsigned int b = bit_length(n); \
argtype t; \
rettype x = (rettype)(n >> (b/2+1)); \
x |= ((rettype)1LU << (b-1)/2); \
while ((t = n/x) < (argtype)x) x = (rettype)((x + t) >> 1); \
return x; \
} \
return (rettype)sqrt((double)n); \
}
DEFINE_INT_SQRT(unsigned long, rb_ulong, unsigned long)
#if SIZEOF_BDIGIT*2 > SIZEOF_LONG
DEFINE_INT_SQRT(BDIGIT, rb_bdigit_dbl, BDIGIT_DBL)
#endif
#define domain_error(msg) \
rb_raise(rb_eMathDomainError, "Numerical argument is out of domain - " #msg)
VALUE rb_big_isqrt(VALUE);
static VALUE
rb_int_s_isqrt(VALUE self, VALUE num)
{
unsigned long n, sq;
if (FIXNUM_P(num)) {
if (FIXNUM_NEGATIVE_P(num)) {
domain_error("isqrt");
}
n = FIX2ULONG(num);
sq = rb_ulong_isqrt(n);
return LONG2FIX(sq);
}
if (RB_TYPE_P(num, T_BIGNUM)) {
size_t biglen;
if (RBIGNUM_NEGATIVE_P(num)) {
domain_error("isqrt");
}
biglen = BIGNUM_LEN(num);
if (biglen == 0) return INT2FIX(0);
#if SIZEOF_BDIGIT <= SIZEOF_LONG
/* short-circuit */
if (biglen == 1) {
n = BIGNUM_DIGITS(num)[0];
sq = rb_ulong_isqrt(n);
return ULONG2NUM(sq);
}
#endif
return rb_big_isqrt(num);
}
return Qnil;
}
/*
* Document-class: ZeroDivisionError
*
@ -5281,6 +5339,7 @@ Init_Numeric(void)
rb_cInteger = rb_define_class("Integer", rb_cNumeric);
rb_undef_alloc_func(rb_cInteger);
rb_undef_method(CLASS_OF(rb_cInteger), "new");
rb_define_singleton_method(rb_cInteger, "sqrt", rb_int_s_isqrt, 1);
rb_define_method(rb_cInteger, "to_s", int_to_s, -1);
rb_define_alias(rb_cInteger, "inspect", "to_s");

View file

@ -464,4 +464,28 @@ class TestInteger < Test::Unit::TestCase
end
assert_equal([0, 1], 10.digits(o))
end
def test_square_root
assert_raise(Math::DomainError) {Integer.sqrt(-1)}
assert_equal(0, Integer.sqrt(0))
(1...4).each {|i| assert_equal(1, Integer.sqrt(i))}
(4...9).each {|i| assert_equal(2, Integer.sqrt(i))}
(9...16).each {|i| assert_equal(3, Integer.sqrt(i))}
(1..40).each do |i|
mesg = "10**#{i}"
s = Integer.sqrt(n = 10**i)
if i.even?
assert_equal(10**(i/2), Integer.sqrt(n), mesg)
else
assert_include((s**2)...(s+1)**2, n, mesg)
end
end
50.step(400, 10) do |i|
exact = 10**(i/2)
x = 10**i
assert_equal(exact, Integer.sqrt(x), "10**#{i}")
assert_equal(exact, Integer.sqrt(x+1), "10**#{i}+1")
assert_equal(exact-1, Integer.sqrt(x-1), "10**#{i}-1")
end
end
end