mirror of
https://github.com/ruby/ruby.git
synced 2022-11-09 12:17:21 -05:00
ext/openssl: refactor OpenSSL::PKey::EC::Point#mul
* ext/openssl/ossl_pkey_ec.c (ossl_ec_point_mul): Validate the arguments before passing to EC_POINT(s)_mul(). Add description of this method. [ruby-core:65152] [Bug #10268] * test/openssl/test_pkey_ec.rb (test_ec_point_mul): Test that OpenSSL::PKey::EC::Point#mul works. git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@55048 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
This commit is contained in:
parent
b43fd8e080
commit
01801f2afd
3 changed files with 99 additions and 61 deletions
|
@ -1,3 +1,12 @@
|
||||||
|
Wed May 18 11:53:49 2016 Kazuki Yamaguchi <k@rhe.jp>
|
||||||
|
|
||||||
|
* ext/openssl/ossl_pkey_ec.c (ossl_ec_point_mul): Validate the
|
||||||
|
arguments before passing to EC_POINT(s)_mul(). Add description of this
|
||||||
|
method. [ruby-core:65152] [Bug #10268]
|
||||||
|
|
||||||
|
* test/openssl/test_pkey_ec.rb (test_ec_point_mul): Test that
|
||||||
|
OpenSSL::PKey::EC::Point#mul works.
|
||||||
|
|
||||||
Wed May 18 11:19:59 2016 Kazuki Yamaguchi <k@rhe.jp>
|
Wed May 18 11:19:59 2016 Kazuki Yamaguchi <k@rhe.jp>
|
||||||
|
|
||||||
* ext/openssl/ossl_bn.c (try_convert_to_bnptr): Extracted from
|
* ext/openssl/ossl_bn.c (try_convert_to_bnptr): Extracted from
|
||||||
|
|
|
@ -1500,74 +1500,84 @@ static VALUE ossl_ec_point_to_bn(VALUE self)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* call-seq:
|
* call-seq:
|
||||||
* point.mul(bn) => point
|
* point.mul(bn1 [, bn2]) => point
|
||||||
* point.mul(bn, bn) => point
|
* point.mul(bns, points [, bn2]) => point
|
||||||
* point.mul([bn], [point]) => point
|
*
|
||||||
* point.mul([bn], [point], bn) => point
|
* Performs elliptic curve point multiplication.
|
||||||
|
*
|
||||||
|
* The first form calculates <tt>bn1 * point + bn2 * G</tt>, where +G+ is the
|
||||||
|
* generator of the group of +point+. +bn2+ may be ommitted, and in that case,
|
||||||
|
* the result is just <tt>bn1 * point</tt>.
|
||||||
|
*
|
||||||
|
* The second form calculates <tt>bns[0] * point + bns[1] * points[0] + ...
|
||||||
|
* + bns[-1] * points[-1] + bn2 * G</tt>. +bn2+ may be ommitted. +bns+ must be
|
||||||
|
* an array of OpenSSL::BN. +points+ must be an array of
|
||||||
|
* OpenSSL::PKey::EC::Point. Please note that <tt>points[0]</tt> is not
|
||||||
|
* multiplied by <tt>bns[0]</tt>, but <tt>bns[1]</tt>.
|
||||||
*/
|
*/
|
||||||
static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
|
static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self)
|
||||||
{
|
{
|
||||||
EC_POINT *point1, *point2;
|
EC_POINT *point_self, *point_result;
|
||||||
const EC_GROUP *group;
|
const EC_GROUP *group;
|
||||||
VALUE group_v = rb_iv_get(self, "@group");
|
VALUE group_v = rb_iv_get(self, "@group");
|
||||||
VALUE bn_v1, bn_v2, r, points_v;
|
VALUE arg1, arg2, arg3, result;
|
||||||
BIGNUM *bn1 = NULL, *bn2 = NULL;
|
const BIGNUM *bn_g = NULL;
|
||||||
|
|
||||||
Require_EC_POINT(self, point1);
|
Require_EC_POINT(self, point_self);
|
||||||
SafeRequire_EC_GROUP(group_v, group);
|
SafeRequire_EC_GROUP(group_v, group);
|
||||||
|
|
||||||
r = rb_obj_alloc(cEC_POINT);
|
result = rb_obj_alloc(cEC_POINT);
|
||||||
ossl_ec_point_initialize(1, &group_v, r);
|
ossl_ec_point_initialize(1, &group_v, result);
|
||||||
Require_EC_POINT(r, point2);
|
Require_EC_POINT(result, point_result);
|
||||||
|
|
||||||
argc = rb_scan_args(argc, argv, "12", &bn_v1, &points_v, &bn_v2);
|
rb_scan_args(argc, argv, "12", &arg1, &arg2, &arg3);
|
||||||
|
if (rb_obj_is_kind_of(arg1, cBN)) {
|
||||||
|
BIGNUM *bn = GetBNPtr(arg1);
|
||||||
|
if (argc >= 2)
|
||||||
|
bn_g = GetBNPtr(arg2);
|
||||||
|
|
||||||
if (rb_obj_is_kind_of(bn_v1, cBN)) {
|
if (EC_POINT_mul(group, point_result, bn_g, point_self, bn, ossl_bn_ctx) != 1)
|
||||||
bn1 = GetBNPtr(bn_v1);
|
ossl_raise(eEC_POINT, NULL);
|
||||||
if (argc >= 2) {
|
|
||||||
bn2 = GetBNPtr(points_v);
|
|
||||||
}
|
|
||||||
if (EC_POINT_mul(group, point2, bn2, point1, bn1, ossl_bn_ctx) != 1)
|
|
||||||
ossl_raise(eEC_POINT, "Multiplication failed");
|
|
||||||
} else {
|
} else {
|
||||||
size_t i, points_len, bignums_len;
|
/*
|
||||||
|
* bignums | arg1[0] | arg1[1] | arg1[2] | ...
|
||||||
|
* points | self | arg2[0] | arg2[1] | ...
|
||||||
|
*/
|
||||||
|
int i, num;
|
||||||
|
VALUE tmp_p, tmp_b;
|
||||||
const EC_POINT **points;
|
const EC_POINT **points;
|
||||||
const BIGNUM **bignums;
|
const BIGNUM **bignums;
|
||||||
|
|
||||||
Check_Type(bn_v1, T_ARRAY);
|
if (!rb_obj_is_kind_of(arg1, rb_cArray) ||
|
||||||
bignums_len = RARRAY_LEN(bn_v1);
|
!rb_obj_is_kind_of(arg2, rb_cArray))
|
||||||
bignums = (const BIGNUM **)OPENSSL_malloc(bignums_len * (int)sizeof(BIGNUM *));
|
ossl_raise(rb_eTypeError, "points must be array");
|
||||||
|
if (RARRAY_LEN(arg1) != RARRAY_LEN(arg2) + 1) /* arg2 must be 1 larger */
|
||||||
|
ossl_raise(rb_eArgError, "bns must be 1 longer than points; see the documentation");
|
||||||
|
|
||||||
for (i = 0; i < bignums_len; ++i) {
|
num = RARRAY_LEN(arg1);
|
||||||
bignums[i] = GetBNPtr(rb_ary_entry(bn_v1, i));
|
bignums = ALLOCV_N(const BIGNUM *, tmp_b, num);
|
||||||
|
for (i = 0; i < num; i++)
|
||||||
|
bignums[i] = GetBNPtr(RARRAY_AREF(arg1, i));
|
||||||
|
|
||||||
|
points = ALLOCV_N(const EC_POINT *, tmp_p, num);
|
||||||
|
points[0] = point_self; /* self */
|
||||||
|
for (i = 0; i < num - 1; i++)
|
||||||
|
SafeRequire_EC_POINT(RARRAY_AREF(arg2, i), points[i + 1]);
|
||||||
|
|
||||||
|
if (argc >= 3)
|
||||||
|
bn_g = GetBNPtr(arg3);
|
||||||
|
|
||||||
|
if (EC_POINTs_mul(group, point_result, bn_g, num, points, bignums, ossl_bn_ctx) != 1) {
|
||||||
|
ALLOCV_END(tmp_b);
|
||||||
|
ALLOCV_END(tmp_p);
|
||||||
|
ossl_raise(eEC_POINT, NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!rb_obj_is_kind_of(points_v, rb_cArray)) {
|
ALLOCV_END(tmp_b);
|
||||||
OPENSSL_free((void *)bignums);
|
ALLOCV_END(tmp_p);
|
||||||
rb_raise(rb_eTypeError, "Argument2 must be an array");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rb_ary_unshift(points_v, self);
|
return result;
|
||||||
points_len = RARRAY_LEN(points_v);
|
|
||||||
points = (const EC_POINT **)OPENSSL_malloc(points_len * (int)sizeof(EC_POINT *));
|
|
||||||
|
|
||||||
for (i = 0; i < points_len; ++i) {
|
|
||||||
Get_EC_POINT(rb_ary_entry(points_v, i), points[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (argc >= 3) {
|
|
||||||
bn2 = GetBNPtr(bn_v2);
|
|
||||||
}
|
|
||||||
if (EC_POINTs_mul(group, point2, bn2, points_len, points, bignums, ossl_bn_ctx) != 1) {
|
|
||||||
OPENSSL_free((void *)bignums);
|
|
||||||
OPENSSL_free((void *)points);
|
|
||||||
ossl_raise(eEC_POINT, "Multiplication failed");
|
|
||||||
}
|
|
||||||
OPENSSL_free((void *)bignums);
|
|
||||||
OPENSSL_free((void *)points);
|
|
||||||
}
|
|
||||||
|
|
||||||
return r;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void no_copy(VALUE klass)
|
static void no_copy(VALUE klass)
|
||||||
|
|
|
@ -190,19 +190,38 @@ class OpenSSL::TestEC < Test::Unit::TestCase
|
||||||
end
|
end
|
||||||
|
|
||||||
def test_ec_point_mul
|
def test_ec_point_mul
|
||||||
ec = OpenSSL::TestUtils::TEST_KEY_EC_P256V1
|
# y^2 = x^3 + 2x + 2 over F_17
|
||||||
p1 = ec.public_key
|
# generator is (5, 1)
|
||||||
bn1 = OpenSSL::BN.new('10')
|
group = OpenSSL::PKey::EC::Group.new(:GFp, 17, 2, 2)
|
||||||
bn2 = OpenSSL::BN.new('20')
|
gen = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new("040501", 16))
|
||||||
|
group.set_generator(gen, 0, 0)
|
||||||
|
|
||||||
p2 = p1.mul(bn1)
|
# 3 * (6, 3) = (16, 13)
|
||||||
assert(p1.group == p2.group)
|
point_a = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new("040603", 16))
|
||||||
p2 = p1.mul(bn1, bn2)
|
result_a1 = point_a.mul(3.to_bn)
|
||||||
assert(p1.group == p2.group)
|
assert_equal("04100D", result_a1.to_bn.to_s(16))
|
||||||
p2 = p1.mul([bn1, bn2], [p1])
|
# 3 * (6, 3) + 3 * (5, 1) = (7, 6)
|
||||||
assert(p1.group == p2.group)
|
result_a2 = point_a.mul(3.to_bn, 3.to_bn)
|
||||||
p2 = p1.mul([bn1, bn2], [p1], bn2)
|
assert_equal("040706", result_a2.to_bn.to_s(16))
|
||||||
assert(p1.group == p2.group)
|
# 3 * point_a = 3 * (6, 3) = (16, 13)
|
||||||
|
result_b1 = point_a.mul([3.to_bn], [])
|
||||||
|
assert_equal("04100D", result_b1.to_bn.to_s(16))
|
||||||
|
# 3 * point_a + 2 * point_a = 3 * (6, 3) + 2 * (6, 3) = (7, 11)
|
||||||
|
result_b1 = point_a.mul([3.to_bn, 2.to_bn], [point_a])
|
||||||
|
assert_equal("04070B", result_b1.to_bn.to_s(16))
|
||||||
|
# 3 * point_a + 5 * point_a.group.generator = 3 * (6, 3) + 5 * (5, 1) = (13, 10)
|
||||||
|
result_b1 = point_a.mul([3.to_bn], [], 5)
|
||||||
|
assert_equal("040D0A", result_b1.to_bn.to_s(16))
|
||||||
|
|
||||||
|
p256_key = OpenSSL::TestUtils::TEST_KEY_EC_P256V1
|
||||||
|
p256_g = p256_key.group
|
||||||
|
assert_equal(p256_key.public_key, p256_g.generator.mul(p256_key.private_key))
|
||||||
|
|
||||||
|
# invalid argument
|
||||||
|
assert_raise(TypeError) { point_a.mul(nil) }
|
||||||
|
assert_raise(ArgumentError) { point_a.mul([1.to_bn], [point_a]) }
|
||||||
|
assert_raise(TypeError) { point_a.mul([1.to_bn], nil) }
|
||||||
|
assert_raise(TypeError) { point_a.mul([nil], []) }
|
||||||
end
|
end
|
||||||
|
|
||||||
# test Group: asn1_flag, point_conversion
|
# test Group: asn1_flag, point_conversion
|
||||||
|
|
Loading…
Reference in a new issue