mirror of
https://github.com/ruby/ruby.git
synced 2022-11-09 12:17:21 -05:00
[ruby/matrix] Optimize **
Avoiding recursive call would imply iterating bits starting from most significant, which is not easy to do efficiently. Any saving would be dwarfed by the multiplications anyways. [Feature #15233]
This commit is contained in:
parent
3b5b309b7b
commit
a83a51932d
Notes:
git
2020-12-05 14:57:22 +09:00
2 changed files with 44 additions and 15 deletions
|
@ -1233,26 +1233,49 @@ class Matrix
|
||||||
# # => 67 96
|
# # => 67 96
|
||||||
# # 48 99
|
# # 48 99
|
||||||
#
|
#
|
||||||
def **(other)
|
def **(exp)
|
||||||
case other
|
case exp
|
||||||
when Integer
|
when Integer
|
||||||
x = self
|
case
|
||||||
if other <= 0
|
when exp == 0
|
||||||
x = self.inverse
|
_make_sure_it_is_invertible = inverse
|
||||||
return self.class.identity(self.column_count) if other == 0
|
self.class.identity(column_count)
|
||||||
other = -other
|
when exp < 0
|
||||||
end
|
inverse.power_int(-exp)
|
||||||
z = nil
|
else
|
||||||
loop do
|
power_int(exp)
|
||||||
z = z ? z * x : x if other[0] == 1
|
|
||||||
return z if (other >>= 1).zero?
|
|
||||||
x *= x
|
|
||||||
end
|
end
|
||||||
when Numeric
|
when Numeric
|
||||||
v, d, v_inv = eigensystem
|
v, d, v_inv = eigensystem
|
||||||
v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** other}) * v_inv
|
v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** exp}) * v_inv
|
||||||
else
|
else
|
||||||
raise ErrOperationNotDefined, ["**", self.class, other.class]
|
raise ErrOperationNotDefined, ["**", self.class, exp.class]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
protected def power_int(exp)
|
||||||
|
# assumes `exp` is an Integer > 0
|
||||||
|
#
|
||||||
|
# Previous algorithm:
|
||||||
|
# build M**2, M**4 = (M**2)**2, M**8, ... and multiplying those you need
|
||||||
|
# e.g. M**0b1011 = M**11 = M * M**2 * M**8
|
||||||
|
# ^ ^
|
||||||
|
# (highlighted the 2 out of 5 multiplications involving `M * x`)
|
||||||
|
#
|
||||||
|
# Current algorithm has same number of multiplications but with lower exponents:
|
||||||
|
# M**11 = M * (M * M**4)**2
|
||||||
|
# ^ ^ ^
|
||||||
|
# (highlighted the 3 out of 5 multiplications involving `M * x`)
|
||||||
|
#
|
||||||
|
# This should be faster for all (non nil-potent) matrices.
|
||||||
|
case
|
||||||
|
when exp == 1
|
||||||
|
self
|
||||||
|
when exp.odd?
|
||||||
|
self * power_int(exp - 1)
|
||||||
|
else
|
||||||
|
sqrt = power_int(exp / 2)
|
||||||
|
sqrt * sqrt
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -448,6 +448,12 @@ class TestMatrix < Test::Unit::TestCase
|
||||||
assert_equal(Matrix[[67,96],[48,99]], Matrix[[7,6],[3,9]] ** 2)
|
assert_equal(Matrix[[67,96],[48,99]], Matrix[[7,6],[3,9]] ** 2)
|
||||||
assert_equal(Matrix.I(5), Matrix.I(5) ** -1)
|
assert_equal(Matrix.I(5), Matrix.I(5) ** -1)
|
||||||
assert_raise(Matrix::ErrOperationNotDefined) { Matrix.I(5) ** Object.new }
|
assert_raise(Matrix::ErrOperationNotDefined) { Matrix.I(5) ** Object.new }
|
||||||
|
|
||||||
|
m = Matrix[[0,2],[1,0]]
|
||||||
|
exp = 0b11101000
|
||||||
|
assert_equal(Matrix.scalar(2, 1 << (exp/2)), m ** exp)
|
||||||
|
exp = 0b11101001
|
||||||
|
assert_equal(Matrix[[0, 2 << (exp/2)], [1 << (exp/2), 0]], m ** exp)
|
||||||
end
|
end
|
||||||
|
|
||||||
def test_det
|
def test_det
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue