diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index df8b425a0f..d5f5a51ab6 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -732,6 +732,44 @@ ossl_pkey_public_to_pem(VALUE self) return ossl_pkey_export_spki(self, 0); } +/* + * call-seq: + * pkey.compare?(another_pkey) -> true | false + * + * Used primarily to check if an OpenSSL::X509::Certificate#public_key compares to its private key. + * + * == Example + * x509 = OpenSSL::X509::Certificate.new(pem_encoded_certificate) + * rsa_key = OpenSSL::PKey::RSA.new(pem_encoded_private_key) + * + * rsa_key.compare?(x509.public_key) => true | false + */ +static VALUE +ossl_pkey_compare(VALUE self, VALUE other) +{ + int ret; + EVP_PKEY *selfPKey; + EVP_PKEY *otherPKey; + + GetPKey(self, selfPKey); + GetPKey(other, otherPKey); + + /* Explicitly check the key type given EVP_PKEY_ASN1_METHOD(3) + * docs param_cmp could return any negative number. + */ + if (EVP_PKEY_id(selfPKey) != EVP_PKEY_id(otherPKey)) + ossl_raise(rb_eTypeError, "cannot match different PKey types"); + + ret = EVP_PKEY_cmp(selfPKey, otherPKey); + + if (ret == 0) + return Qfalse; + else if (ret == 1) + return Qtrue; + else + ossl_raise(ePKeyError, "EVP_PKEY_cmp"); +} + /* * call-seq: * pkey.sign(digest, data) -> String @@ -1031,6 +1069,7 @@ Init_ossl_pkey(void) rb_define_method(cPKey, "private_to_pem", ossl_pkey_private_to_pem, -1); rb_define_method(cPKey, "public_to_der", ossl_pkey_public_to_der, 0); rb_define_method(cPKey, "public_to_pem", ossl_pkey_public_to_pem, 0); + rb_define_method(cPKey, "compare?", ossl_pkey_compare, 1); rb_define_method(cPKey, "sign", ossl_pkey_sign, 2); rb_define_method(cPKey, "verify", ossl_pkey_verify, 3); diff --git a/test/openssl/test_pkey.rb b/test/openssl/test_pkey.rb index 5307fe5b08..0a516f98e8 100644 --- a/test/openssl/test_pkey.rb +++ b/test/openssl/test_pkey.rb @@ -151,4 +151,22 @@ class OpenSSL::TestPKey < OpenSSL::PKeyTestCase assert_equal bob_pem, bob.public_to_pem assert_equal [shared_secret].pack("H*"), alice.derive(bob) end + + def test_compare? + key1 = Fixtures.pkey("rsa1024") + key2 = Fixtures.pkey("rsa1024") + key3 = Fixtures.pkey("rsa2048") + key4 = Fixtures.pkey("dh-1") + + assert_equal(true, key1.compare?(key2)) + assert_equal(true, key1.public_key.compare?(key2)) + assert_equal(true, key2.compare?(key1)) + assert_equal(true, key2.public_key.compare?(key1)) + + assert_equal(false, key1.compare?(key3)) + + assert_raise(TypeError) do + key1.compare?(key4) + end + end end