diff --git a/ChangeLog b/ChangeLog index 8ea39b4c7f..1e81ce908e 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,9 @@ +Thu Sep 8 10:08:35 2016 Kazuki Yamaguchi + + * {ext,test}/openssl: Import Ruby/OpenSSL 2.0.0.beta.2. The full commit + history since v2.0.0.beta.1 can be found at: + https://github.com/ruby/openssl/compare/v2.0.0.beta.1...v2.0.0.beta.2 + Thu Sep 8 07:23:34 2016 SHIBATA Hiroshi * lib/rdoc/*, test/rdoc/*: Update rdoc-5.0.0.beta2 diff --git a/ext/openssl/extconf.rb b/ext/openssl/extconf.rb index bf78428619..20c67c6b50 100644 --- a/ext/openssl/extconf.rb +++ b/ext/openssl/extconf.rb @@ -152,7 +152,5 @@ have_func("X509_get0_notBefore") Logging::message "=== Checking done. ===\n" create_header -create_makefile("openssl") {|conf| - conf << "THREAD_MODEL = #{CONFIG["THREAD_MODEL"]}\n" -} +create_makefile("openssl") Logging::message "Done.\n" diff --git a/ext/openssl/lib/openssl/ssl.rb b/ext/openssl/lib/openssl/ssl.rb index 519ea11a54..190f504276 100644 --- a/ext/openssl/lib/openssl/ssl.rb +++ b/ext/openssl/lib/openssl/ssl.rb @@ -73,16 +73,6 @@ module OpenSSL DEFAULT_CERT_STORE.set_default_paths DEFAULT_CERT_STORE.flags = OpenSSL::X509::V_FLAG_CRL_CHECK_ALL - # :nodoc: - INIT_VARS = ["cert", "key", "client_ca", "ca_file", "ca_path", - "timeout", "verify_mode", "verify_depth", "renegotiation_cb", - "verify_callback", "cert_store", "extra_chain_cert", - "client_cert_cb", "session_id_context", "tmp_dh_callback", - "session_get_cb", "session_new_cb", "session_remove_cb", - "tmp_ecdh_callback", "servername_cb", "npn_protocols", - "alpn_protocols", "alpn_select_cb", - "npn_select_cb", "verify_hostname"].map { |x| "@#{x}" } - # A callback invoked when DH parameters are required. # # The callback is invoked with the Session for the key exchange, an @@ -110,10 +100,8 @@ module OpenSSL # # You can get a list of valid methods with OpenSSL::SSL::SSLContext::METHODS def initialize(version = nil) - INIT_VARS.each { |v| instance_variable_set v, nil } - self.options = self.options | OpenSSL::SSL::OP_ALL - return unless version - self.ssl_version = version + self.options |= OpenSSL::SSL::OP_ALL + self.ssl_version = version if version end ## diff --git a/ext/openssl/openssl.gemspec b/ext/openssl/openssl.gemspec index 4254a9fd6b..48191fa0e9 100644 --- a/ext/openssl/openssl.gemspec +++ b/ext/openssl/openssl.gemspec @@ -1,19 +1,19 @@ # -*- encoding: utf-8 -*- -# stub: openssl 2.0.0.beta.1 ruby lib +# stub: openssl 2.0.0.beta.2 ruby lib # stub: ext/openssl/extconf.rb Gem::Specification.new do |s| s.name = "openssl".freeze - s.version = "2.0.0.beta.1" + s.version = "2.0.0.beta.2" s.required_rubygems_version = Gem::Requirement.new("> 1.3.1".freeze) if s.respond_to? :required_rubygems_version= s.require_paths = ["lib".freeze] s.authors = ["Martin Bosslet".freeze, "SHIBATA Hiroshi".freeze, "Zachary Scott".freeze, "Kazuki Yamaguchi".freeze] - s.date = "2016-08-29" + s.date = "2016-09-08" s.description = "It wraps the OpenSSL library.".freeze s.email = ["ruby-core@ruby-lang.org".freeze] s.extensions = ["ext/openssl/extconf.rb".freeze] - s.extra_rdoc_files = ["CONTRIBUTING.md".freeze, "History.md".freeze, "README.md".freeze] + s.extra_rdoc_files = ["CONTRIBUTING.md".freeze, "README.md".freeze, "History.md".freeze] s.files = ["BSDL".freeze, "CONTRIBUTING.md".freeze, "History.md".freeze, "LICENSE.txt".freeze, "README.md".freeze, "ext/openssl/deprecation.rb".freeze, "ext/openssl/extconf.rb".freeze, "ext/openssl/openssl_missing.c".freeze, "ext/openssl/openssl_missing.h".freeze, "ext/openssl/ossl.c".freeze, "ext/openssl/ossl.h".freeze, "ext/openssl/ossl_asn1.c".freeze, "ext/openssl/ossl_asn1.h".freeze, "ext/openssl/ossl_bio.c".freeze, "ext/openssl/ossl_bio.h".freeze, "ext/openssl/ossl_bn.c".freeze, "ext/openssl/ossl_bn.h".freeze, "ext/openssl/ossl_cipher.c".freeze, "ext/openssl/ossl_cipher.h".freeze, "ext/openssl/ossl_config.c".freeze, "ext/openssl/ossl_config.h".freeze, "ext/openssl/ossl_digest.c".freeze, "ext/openssl/ossl_digest.h".freeze, "ext/openssl/ossl_engine.c".freeze, "ext/openssl/ossl_engine.h".freeze, "ext/openssl/ossl_hmac.c".freeze, "ext/openssl/ossl_hmac.h".freeze, "ext/openssl/ossl_ns_spki.c".freeze, "ext/openssl/ossl_ns_spki.h".freeze, "ext/openssl/ossl_ocsp.c".freeze, "ext/openssl/ossl_ocsp.h".freeze, "ext/openssl/ossl_pkcs12.c".freeze, "ext/openssl/ossl_pkcs12.h".freeze, "ext/openssl/ossl_pkcs5.c".freeze, "ext/openssl/ossl_pkcs5.h".freeze, "ext/openssl/ossl_pkcs7.c".freeze, "ext/openssl/ossl_pkcs7.h".freeze, "ext/openssl/ossl_pkey.c".freeze, "ext/openssl/ossl_pkey.h".freeze, "ext/openssl/ossl_pkey_dh.c".freeze, "ext/openssl/ossl_pkey_dsa.c".freeze, "ext/openssl/ossl_pkey_ec.c".freeze, "ext/openssl/ossl_pkey_rsa.c".freeze, "ext/openssl/ossl_rand.c".freeze, "ext/openssl/ossl_rand.h".freeze, "ext/openssl/ossl_ssl.c".freeze, "ext/openssl/ossl_ssl.h".freeze, "ext/openssl/ossl_ssl_session.c".freeze, "ext/openssl/ossl_version.h".freeze, "ext/openssl/ossl_x509.c".freeze, "ext/openssl/ossl_x509.h".freeze, "ext/openssl/ossl_x509attr.c".freeze, "ext/openssl/ossl_x509cert.c".freeze, "ext/openssl/ossl_x509crl.c".freeze, "ext/openssl/ossl_x509ext.c".freeze, "ext/openssl/ossl_x509name.c".freeze, "ext/openssl/ossl_x509req.c".freeze, "ext/openssl/ossl_x509revoked.c".freeze, "ext/openssl/ossl_x509store.c".freeze, "ext/openssl/ruby_missing.h".freeze, "lib/openssl.rb".freeze, "lib/openssl/bn.rb".freeze, "lib/openssl/buffering.rb".freeze, "lib/openssl/cipher.rb".freeze, "lib/openssl/config.rb".freeze, "lib/openssl/digest.rb".freeze, "lib/openssl/pkey.rb".freeze, "lib/openssl/ssl.rb".freeze, "lib/openssl/x509.rb".freeze] s.homepage = "https://www.ruby-lang.org/".freeze s.licenses = ["Ruby".freeze] diff --git a/ext/openssl/ossl_pkcs12.c b/ext/openssl/ossl_pkcs12.c index a7daad208e..0b9c7816b5 100644 --- a/ext/openssl/ossl_pkcs12.c +++ b/ext/openssl/ossl_pkcs12.c @@ -190,15 +190,17 @@ ossl_pkcs12_initialize(int argc, VALUE *argv, VALUE self) if(!PKCS12_parse(pkcs, passphrase, &key, &x509, &x509s)) ossl_raise(ePKCS12Error, "PKCS12_parse"); ERR_pop_to_mark(); - pkey = rb_protect((VALUE (*)(VALUE))ossl_pkey_new, (VALUE)key, - &st); /* NO DUP */ - if(st) goto err; - cert = rb_protect((VALUE (*)(VALUE))ossl_x509_new, (VALUE)x509, &st); - if(st) goto err; - if(x509s){ - ca = - rb_protect((VALUE (*)(VALUE))ossl_x509_sk2ary, (VALUE)x509s, &st); - if(st) goto err; + if (key) { + pkey = rb_protect((VALUE (*)(VALUE))ossl_pkey_new, (VALUE)key, &st); + if (st) goto err; + } + if (x509) { + cert = rb_protect((VALUE (*)(VALUE))ossl_x509_new, (VALUE)x509, &st); + if (st) goto err; + } + if (x509s) { + ca = rb_protect((VALUE (*)(VALUE))ossl_x509_sk2ary, (VALUE)x509s, &st); + if (st) goto err; } err: diff --git a/ext/openssl/ossl_pkey_ec.c b/ext/openssl/ossl_pkey_ec.c index 9a884c7ef5..f0a31231b4 100644 --- a/ext/openssl/ossl_pkey_ec.c +++ b/ext/openssl/ossl_pkey_ec.c @@ -6,17 +6,6 @@ #if !defined(OPENSSL_NO_EC) && (OPENSSL_VERSION_NUMBER >= 0x0090802fL) -typedef struct { - EC_GROUP *group; - int dont_free; -} ossl_ec_group; - -typedef struct { - EC_POINT *point; - int dont_free; -} ossl_ec_point; - - #define EXPORT_PEM 0 #define EXPORT_DER 1 @@ -29,66 +18,39 @@ static const rb_data_type_t ossl_ec_point_type; ossl_raise(rb_eRuntimeError, "THIS IS NOT A EC PKEY!"); \ } \ } while (0) - -#define SafeGet_ec_group(obj, group) do { \ - OSSL_Check_Kind((obj), cEC_GROUP); \ - TypedData_Get_Struct((obj), ossl_ec_group, &ossl_ec_group_type, (group)); \ -} while(0) - -#define Get_EC_KEY(obj, key) do { \ - EVP_PKEY *pkey; \ - GetPKeyEC((obj), pkey); \ - (key) = EVP_PKEY_get0_EC_KEY(pkey); \ -} while(0) - -#define Require_EC_KEY(obj, key) do { \ - Get_EC_KEY((obj), (key)); \ - if ((key) == NULL) \ - ossl_raise(eECError, "EC_KEY is not initialized"); \ -} while(0) - -#define SafeRequire_EC_KEY(obj, key) do { \ - OSSL_Check_Kind((obj), cEC); \ - Require_EC_KEY((obj), (key)); \ +#define GetEC(obj, key) do { \ + EVP_PKEY *_pkey; \ + GetPKeyEC(obj, _pkey); \ + (key) = EVP_PKEY_get0_EC_KEY(_pkey); \ +} while (0) +#define SafeGetEC(obj, key) do { \ + OSSL_Check_Kind(obj, cEC); \ + GetEC(obj, key); \ } while (0) -#define Get_EC_GROUP(obj, g) do { \ - ossl_ec_group *ec_group; \ - TypedData_Get_Struct((obj), ossl_ec_group, &ossl_ec_group_type, ec_group); \ - if (ec_group == NULL) \ - ossl_raise(eEC_GROUP, "missing ossl_ec_group structure"); \ - (g) = ec_group->group; \ -} while(0) - -#define Require_EC_GROUP(obj, group) do { \ - Get_EC_GROUP((obj), (group)); \ +#define GetECGroup(obj, group) do { \ + TypedData_Get_Struct(obj, EC_GROUP, &ossl_ec_group_type, group); \ if ((group) == NULL) \ - ossl_raise(eEC_GROUP, "EC_GROUP is not initialized"); \ -} while(0) - -#define SafeRequire_EC_GROUP(obj, group) do { \ + ossl_raise(eEC_GROUP, "EC_GROUP is not initialized"); \ +} while (0) +#define SafeGetECGroup(obj, group) do { \ OSSL_Check_Kind((obj), cEC_GROUP); \ - Require_EC_GROUP((obj), (group)); \ -} while(0) + GetECGroup(obj, group); \ +} while (0) -#define Get_EC_POINT(obj, p) do { \ - ossl_ec_point *ec_point; \ - TypedData_Get_Struct((obj), ossl_ec_point, &ossl_ec_point_type, ec_point); \ - if (ec_point == NULL) \ - ossl_raise(eEC_POINT, "missing ossl_ec_point structure"); \ - (p) = ec_point->point; \ -} while(0) - -#define Require_EC_POINT(obj, point) do { \ - Get_EC_POINT((obj), (point)); \ +#define GetECPoint(obj, point) do { \ + TypedData_Get_Struct(obj, EC_POINT, &ossl_ec_point_type, point); \ if ((point) == NULL) \ - ossl_raise(eEC_POINT, "EC_POINT is not initialized"); \ -} while(0) - -#define SafeRequire_EC_POINT(obj, point) do { \ + ossl_raise(eEC_POINT, "EC_POINT is not initialized"); \ +} while (0) +#define SafeGetECPoint(obj, point) do { \ OSSL_Check_Kind((obj), cEC_POINT); \ - Require_EC_POINT((obj), (point)); \ + GetECPoint(obj, point); \ } while(0) +#define GetECPointGroup(obj, group) do { \ + VALUE _group = rb_attr_get(obj, id_i_group); \ + SafeGetECGroup(_group, group); \ +} while (0) VALUE cEC; VALUE eECError; @@ -108,7 +70,10 @@ static ID ID_uncompressed; static ID ID_compressed; static ID ID_hybrid; -static ID id_i_group, id_i_key; +static ID id_i_group; + +static VALUE ec_group_new(const EC_GROUP *group); +static VALUE ec_point_new(const EC_POINT *point, const EC_GROUP *group); static VALUE ec_instance(VALUE klass, EC_KEY *ec) { @@ -163,8 +128,7 @@ ec_key_new_from_group(VALUE arg) if (rb_obj_is_kind_of(arg, cEC_GROUP)) { EC_GROUP *group; - SafeRequire_EC_GROUP(arg, group); - + SafeGetECGroup(arg, group); if (!(ec = EC_KEY_new())) ossl_raise(eECError, NULL); @@ -244,7 +208,7 @@ static VALUE ossl_ec_key_initialize(int argc, VALUE *argv, VALUE self) } else if (rb_obj_is_kind_of(arg, cEC)) { EC_KEY *other_ec = NULL; - SafeRequire_EC_KEY(arg, other_ec); + SafeGetEC(arg, other_ec); if (!(ec = EC_KEY_dup(other_ec))) ossl_raise(eECError, NULL); } else if (rb_obj_is_kind_of(arg, cEC_GROUP)) { @@ -281,8 +245,6 @@ static VALUE ossl_ec_key_initialize(int argc, VALUE *argv, VALUE self) ossl_raise(eECError, "EVP_PKEY_assign_EC_KEY"); } - rb_ivar_set(self, id_i_group, Qnil); - return self; } @@ -295,7 +257,7 @@ ossl_ec_key_initialize_copy(VALUE self, VALUE other) GetPKey(self, pkey); if (EVP_PKEY_base_id(pkey) != EVP_PKEY_NONE) ossl_raise(eECError, "EC already initialized"); - SafeRequire_EC_KEY(other, ec); + SafeGetEC(other, ec); ec_new = EC_KEY_dup(ec); if (!ec_new) @@ -304,80 +266,46 @@ ossl_ec_key_initialize_copy(VALUE self, VALUE other) EC_KEY_free(ec_new); ossl_raise(eECError, "EVP_PKEY_assign_EC_KEY"); } - rb_ivar_set(self, id_i_group, Qnil); /* EC_KEY_dup() also copies the EC_GROUP */ return self; } /* - * call-seq: - * key.group => group + * call-seq: + * key.group => group * - * Returns a constant OpenSSL::EC::Group that is tied to the key. - * Modifying the returned group can make the key invalid. + * Returns the EC::Group that the key is associated with. Modifying the returned + * group does not affect +key+. */ -static VALUE ossl_ec_key_get_group(VALUE self) +static VALUE +ossl_ec_key_get_group(VALUE self) { - VALUE group_v; EC_KEY *ec; - ossl_ec_group *ec_group; - EC_GROUP *group; + const EC_GROUP *group; - Require_EC_KEY(self, ec); + GetEC(self, ec); + group = EC_KEY_get0_group(ec); + if (!group) + return Qnil; - group_v = rb_attr_get(self, id_i_group); - if (!NIL_P(group_v)) - return group_v; - - if ((group = (EC_GROUP *)EC_KEY_get0_group(ec)) != NULL) { - group_v = rb_obj_alloc(cEC_GROUP); - SafeGet_ec_group(group_v, ec_group); - ec_group->group = group; - ec_group->dont_free = 1; - rb_ivar_set(group_v, id_i_key, self); - rb_ivar_set(self, id_i_group, group_v); - return group_v; - } - - return Qnil; + return ec_group_new(group); } /* - * call-seq: - * key.group = group => group + * call-seq: + * key.group = group * - * Returns the same object passed, not the group object associated with the key. - * If you wish to access the group object tied to the key call key.group after setting - * the group. - * - * Setting the group will immediately destroy any previously assigned group object. - * The group is internally copied by OpenSSL. Modifying the original group after - * assignment will not effect the internal key structure. - * (your changes may be lost). BE CAREFUL. - * - * EC_KEY_set_group calls EC_GROUP_free(key->group) then EC_GROUP_dup(), not EC_GROUP_copy. - * This documentation is accurate for OpenSSL 0.9.8b. + * Sets the EC::Group for the key. The group structure is internally copied so + * modifition to +group+ after assigning to a key has no effect on the key. */ -static VALUE ossl_ec_key_set_group(VALUE self, VALUE group_v) +static VALUE +ossl_ec_key_set_group(VALUE self, VALUE group_v) { - VALUE old_group_v; EC_KEY *ec; EC_GROUP *group; - Require_EC_KEY(self, ec); - SafeRequire_EC_GROUP(group_v, group); - - old_group_v = rb_attr_get(self, id_i_group); - if (!NIL_P(old_group_v)) { - ossl_ec_group *old_ec_group; - SafeGet_ec_group(old_group_v, old_ec_group); - - old_ec_group->group = NULL; - old_ec_group->dont_free = 0; - rb_ivar_set(old_group_v, id_i_key, Qnil); - } - - rb_ivar_set(self, id_i_group, Qnil); + GetEC(self, ec); + SafeGetECGroup(group_v, group); if (EC_KEY_set_group(ec, group) != 1) ossl_raise(eECError, "EC_KEY_set_group"); @@ -396,8 +324,7 @@ static VALUE ossl_ec_key_get_private_key(VALUE self) EC_KEY *ec; const BIGNUM *bn; - Require_EC_KEY(self, ec); - + GetEC(self, ec); if ((bn = EC_KEY_get0_private_key(ec)) == NULL) return Qnil; @@ -415,7 +342,7 @@ static VALUE ossl_ec_key_set_private_key(VALUE self, VALUE private_key) EC_KEY *ec; BIGNUM *bn = NULL; - Require_EC_KEY(self, ec); + GetEC(self, ec); if (!NIL_P(private_key)) bn = GetBNPtr(private_key); @@ -432,26 +359,6 @@ static VALUE ossl_ec_key_set_private_key(VALUE self, VALUE private_key) return private_key; } - -static VALUE ossl_ec_point_dup(const EC_POINT *point, VALUE group_v) -{ - VALUE obj; - const EC_GROUP *group; - ossl_ec_point *new_point; - - obj = rb_obj_alloc(cEC_POINT); - TypedData_Get_Struct(obj, ossl_ec_point, &ossl_ec_point_type, new_point); - - SafeRequire_EC_GROUP(group_v, group); - - new_point->point = EC_POINT_dup(point, group); - if (new_point->point == NULL) - ossl_raise(eEC_POINT, "EC_POINT_dup"); - rb_ivar_set(obj, id_i_group, group_v); - - return obj; -} - /* * call-seq: * key.public_key => OpenSSL::PKey::EC::Point @@ -462,18 +369,12 @@ static VALUE ossl_ec_key_get_public_key(VALUE self) { EC_KEY *ec; const EC_POINT *point; - VALUE group; - - Require_EC_KEY(self, ec); + GetEC(self, ec); if ((point = EC_KEY_get0_public_key(ec)) == NULL) return Qnil; - group = rb_funcall(self, rb_intern("group"), 0); - if (NIL_P(group)) - ossl_raise(eECError, "EC_KEY_get0_get0_group (has public_key but no group???"); - - return ossl_ec_point_dup(point, group); + return ec_point_new(point, EC_KEY_get0_group(ec)); } /* @@ -487,9 +388,9 @@ static VALUE ossl_ec_key_set_public_key(VALUE self, VALUE public_key) EC_KEY *ec; EC_POINT *point = NULL; - Require_EC_KEY(self, ec); + GetEC(self, ec); if (!NIL_P(public_key)) - SafeRequire_EC_POINT(public_key, point); + SafeGetECPoint(public_key, point); switch (EC_KEY_set_public_key(ec, point)) { case 1: @@ -515,7 +416,7 @@ static VALUE ossl_ec_key_is_public(VALUE self) { EC_KEY *ec; - Require_EC_KEY(self, ec); + GetEC(self, ec); return EC_KEY_get0_public_key(ec) ? Qtrue : Qfalse; } @@ -531,7 +432,7 @@ static VALUE ossl_ec_key_is_private(VALUE self) { EC_KEY *ec; - Require_EC_KEY(self, ec); + GetEC(self, ec); return EC_KEY_get0_private_key(ec) ? Qtrue : Qfalse; } @@ -545,7 +446,7 @@ static VALUE ossl_ec_key_to_string(VALUE self, VALUE ciph, VALUE pass, int forma VALUE str; const EVP_CIPHER *cipher = NULL; - Require_EC_KEY(self, ec); + GetEC(self, ec); if (EC_KEY_get0_public_key(ec) == NULL) ossl_raise(eECError, "can't export - no public key set"); @@ -636,7 +537,7 @@ static VALUE ossl_ec_key_to_text(VALUE self) BIO *out; VALUE str; - Require_EC_KEY(self, ec); + GetEC(self, ec); if (!(out = BIO_new(BIO_s_mem()))) { ossl_raise(eECError, "BIO_new(BIO_s_mem())"); } @@ -667,8 +568,7 @@ static VALUE ossl_ec_key_generate_key(VALUE self) { EC_KEY *ec; - Require_EC_KEY(self, ec); - + GetEC(self, ec); if (EC_KEY_generate_key(ec) != 1) ossl_raise(eECError, "EC_KEY_generate_key"); @@ -687,8 +587,7 @@ static VALUE ossl_ec_key_check_key(VALUE self) { EC_KEY *ec; - Require_EC_KEY(self, ec); - + GetEC(self, ec); if (EC_KEY_check_key(ec) != 1) ossl_raise(eECError, "EC_KEY_check_key"); @@ -708,8 +607,8 @@ static VALUE ossl_ec_key_dh_compute_key(VALUE self, VALUE pubkey) int buf_len; VALUE str; - Require_EC_KEY(self, ec); - SafeRequire_EC_POINT(pubkey, point); + GetEC(self, ec); + SafeGetECPoint(pubkey, point); /* BUG: need a way to figure out the maximum string size */ buf_len = 1024; @@ -738,7 +637,7 @@ static VALUE ossl_ec_key_dsa_sign_asn1(VALUE self, VALUE data) unsigned int buf_len; VALUE str; - Require_EC_KEY(self, ec); + GetEC(self, ec); StringValue(data); if (EC_KEY_get0_private_key(ec) == NULL) @@ -763,7 +662,7 @@ static VALUE ossl_ec_key_dsa_verify_asn1(VALUE self, VALUE data, VALUE sig) { EC_KEY *ec; - Require_EC_KEY(self, ec); + GetEC(self, ec); StringValue(data); StringValue(sig); @@ -778,12 +677,13 @@ static VALUE ossl_ec_key_dsa_verify_asn1(VALUE self, VALUE data, VALUE sig) UNREACHABLE; } -static void ossl_ec_group_free(void *ptr) +/* + * OpenSSL::PKey::EC::Group + */ +static void +ossl_ec_group_free(void *ptr) { - ossl_ec_group *ec_group = ptr; - if (!ec_group->dont_free && ec_group->group) - EC_GROUP_clear_free(ec_group->group); - ruby_xfree(ec_group); + EC_GROUP_clear_free(ptr); } static const rb_data_type_t ossl_ec_group_type = { @@ -794,12 +694,23 @@ static const rb_data_type_t ossl_ec_group_type = { 0, 0, RUBY_TYPED_FREE_IMMEDIATELY, }; -static VALUE ossl_ec_group_alloc(VALUE klass) +static VALUE +ossl_ec_group_alloc(VALUE klass) { - ossl_ec_group *ec_group; - VALUE obj; + return TypedData_Wrap_Struct(klass, &ossl_ec_group_type, NULL); +} - obj = TypedData_Make_Struct(klass, ossl_ec_group, &ossl_ec_group_type, ec_group); +static VALUE +ec_group_new(const EC_GROUP *group) +{ + VALUE obj; + EC_GROUP *group_new; + + obj = ossl_ec_group_alloc(cEC_GROUP); + group_new = EC_GROUP_dup(group); + if (!group_new) + ossl_raise(eEC_GROUP, "EC_GROUP_dup"); + RTYPEDDATA_DATA(obj) = group_new; return obj; } @@ -828,11 +739,10 @@ static VALUE ossl_ec_group_alloc(VALUE klass) static VALUE ossl_ec_group_initialize(int argc, VALUE *argv, VALUE self) { VALUE arg1, arg2, arg3, arg4; - ossl_ec_group *ec_group; - EC_GROUP *group = NULL; + EC_GROUP *group; - TypedData_Get_Struct(self, ossl_ec_group, &ossl_ec_group_type, ec_group); - if (ec_group->group != NULL) + TypedData_Get_Struct(self, EC_GROUP, &ossl_ec_group_type, group); + if (group) ossl_raise(rb_eRuntimeError, "EC_GROUP is already initialized"); switch (rb_scan_args(argc, argv, "13", &arg1, &arg2, &arg3, &arg4)) { @@ -862,7 +772,7 @@ static VALUE ossl_ec_group_initialize(int argc, VALUE *argv, VALUE self) } else if (rb_obj_is_kind_of(arg1, cEC_GROUP)) { const EC_GROUP *arg1_group; - SafeRequire_EC_GROUP(arg1, arg1_group); + SafeGetECGroup(arg1, arg1_group); if ((group = EC_GROUP_dup(arg1_group)) == NULL) ossl_raise(eEC_GROUP, "EC_GROUP_dup"); } else { @@ -925,8 +835,7 @@ static VALUE ossl_ec_group_initialize(int argc, VALUE *argv, VALUE self) if (group == NULL) ossl_raise(eEC_GROUP, ""); - - ec_group->group = group; + RTYPEDDATA_DATA(self) = group; return self; } @@ -934,19 +843,17 @@ static VALUE ossl_ec_group_initialize(int argc, VALUE *argv, VALUE self) static VALUE ossl_ec_group_initialize_copy(VALUE self, VALUE other) { - ossl_ec_group *ec_group; - EC_GROUP *orig; + EC_GROUP *group, *group_new; - TypedData_Get_Struct(self, ossl_ec_group, &ossl_ec_group_type, ec_group); - if (ec_group->group) + TypedData_Get_Struct(self, EC_GROUP, &ossl_ec_group_type, group_new); + if (group_new) ossl_raise(eEC_GROUP, "EC::Group already initialized"); - SafeRequire_EC_GROUP(other, orig); + SafeGetECGroup(other, group); - ec_group->group = EC_GROUP_dup(orig); - if (!ec_group->group) + group_new = EC_GROUP_dup(group); + if (!group_new) ossl_raise(eEC_GROUP, "EC_GROUP_dup"); - - rb_ivar_set(self, id_i_key, Qnil); + RTYPEDDATA_DATA(self) = group_new; return self; } @@ -963,8 +870,8 @@ static VALUE ossl_ec_group_eql(VALUE a, VALUE b) { EC_GROUP *group1 = NULL, *group2 = NULL; - Require_EC_GROUP(a, group1); - SafeRequire_EC_GROUP(b, group2); + GetECGroup(a, group1); + SafeGetECGroup(b, group2); if (EC_GROUP_cmp(group1, group2, ossl_bn_ctx) == 1) return Qfalse; @@ -982,14 +889,15 @@ static VALUE ossl_ec_group_eql(VALUE a, VALUE b) */ static VALUE ossl_ec_group_get_generator(VALUE self) { - VALUE point_obj; - EC_GROUP *group = NULL; + EC_GROUP *group; + const EC_POINT *generator; - Require_EC_GROUP(self, group); + GetECGroup(self, group); + generator = EC_GROUP_get0_generator(group); + if (!generator) + return Qnil; - point_obj = ossl_ec_point_dup(EC_GROUP_get0_generator(group), self); - - return point_obj; + return ec_point_new(generator, group); } /* @@ -1007,8 +915,8 @@ static VALUE ossl_ec_group_set_generator(VALUE self, VALUE generator, VALUE orde const EC_POINT *point; const BIGNUM *o, *co; - Require_EC_GROUP(self, group); - SafeRequire_EC_POINT(generator, point); + GetECGroup(self, group); + SafeGetECPoint(generator, point); o = GetBNPtr(order); co = GetBNPtr(cofactor); @@ -1032,7 +940,7 @@ static VALUE ossl_ec_group_get_order(VALUE self) BIGNUM *bn; EC_GROUP *group = NULL; - Require_EC_GROUP(self, group); + GetECGroup(self, group); bn_obj = ossl_bn_new(NULL); bn = GetBNPtr(bn_obj); @@ -1057,7 +965,7 @@ static VALUE ossl_ec_group_get_cofactor(VALUE self) BIGNUM *bn; EC_GROUP *group = NULL; - Require_EC_GROUP(self, group); + GetECGroup(self, group); bn_obj = ossl_bn_new(NULL); bn = GetBNPtr(bn_obj); @@ -1081,7 +989,7 @@ static VALUE ossl_ec_group_get_curve_name(VALUE self) EC_GROUP *group = NULL; int nid; - Get_EC_GROUP(self, group); + GetECGroup(self, group); if (group == NULL) return Qnil; @@ -1141,8 +1049,7 @@ static VALUE ossl_ec_group_get_asn1_flag(VALUE self) EC_GROUP *group = NULL; int flag; - Require_EC_GROUP(self, group); - + GetECGroup(self, group); flag = EC_GROUP_get_asn1_flag(group); return INT2NUM(flag); @@ -1166,8 +1073,7 @@ static VALUE ossl_ec_group_set_asn1_flag(VALUE self, VALUE flag_v) { EC_GROUP *group = NULL; - Require_EC_GROUP(self, group); - + GetECGroup(self, group); EC_GROUP_set_asn1_flag(group, NUM2INT(flag_v)); return flag_v; @@ -1187,8 +1093,7 @@ static VALUE ossl_ec_group_get_point_conversion_form(VALUE self) point_conversion_form_t form; VALUE ret; - Require_EC_GROUP(self, group); - + GetECGroup(self, group); form = EC_GROUP_get_point_conversion_form(group); switch (form) { @@ -1226,7 +1131,7 @@ static VALUE ossl_ec_group_set_point_conversion_form(VALUE self, VALUE form_v) point_conversion_form_t form; ID form_id = SYM2ID(form_v); - Require_EC_GROUP(self, group); + GetECGroup(self, group); if (form_id == ID_uncompressed) { form = POINT_CONVERSION_UNCOMPRESSED; @@ -1254,8 +1159,7 @@ static VALUE ossl_ec_group_get_seed(VALUE self) EC_GROUP *group = NULL; size_t seed_len; - Require_EC_GROUP(self, group); - + GetECGroup(self, group); seed_len = EC_GROUP_get_seed_len(group); if (seed_len == 0) @@ -1274,7 +1178,7 @@ static VALUE ossl_ec_group_set_seed(VALUE self, VALUE seed) { EC_GROUP *group = NULL; - Require_EC_GROUP(self, group); + GetECGroup(self, group); StringValue(seed); if (EC_GROUP_set_seed(group, (unsigned char *)RSTRING_PTR(seed), RSTRING_LEN(seed)) != (size_t)RSTRING_LEN(seed)) @@ -1295,7 +1199,7 @@ static VALUE ossl_ec_group_get_degree(VALUE self) { EC_GROUP *group = NULL; - Require_EC_GROUP(self, group); + GetECGroup(self, group); return INT2NUM(EC_GROUP_get_degree(group)); } @@ -1307,7 +1211,7 @@ static VALUE ossl_ec_group_to_string(VALUE self, int format) int i = -1; VALUE str; - Get_EC_GROUP(self, group); + GetECGroup(self, group); if (!(out = BIO_new(BIO_s_mem()))) ossl_raise(eEC_GROUP, "BIO_new(BIO_s_mem())"); @@ -1368,7 +1272,7 @@ static VALUE ossl_ec_group_to_text(VALUE self) BIO *out; VALUE str; - Require_EC_GROUP(self, group); + GetECGroup(self, group); if (!(out = BIO_new(BIO_s_mem()))) { ossl_raise(eEC_GROUP, "BIO_new(BIO_s_mem())"); } @@ -1382,28 +1286,41 @@ static VALUE ossl_ec_group_to_text(VALUE self) } -static void ossl_ec_point_free(void *ptr) +/* + * OpenSSL::PKey::EC::Point + */ +static void +ossl_ec_point_free(void *ptr) { - ossl_ec_point *ec_point = ptr; - if (!ec_point->dont_free && ec_point->point) - EC_POINT_clear_free(ec_point->point); - ruby_xfree(ec_point); + EC_POINT_clear_free(ptr); } static const rb_data_type_t ossl_ec_point_type = { - "OpenSSL/ec_point", + "OpenSSL/EC_POINT", { 0, ossl_ec_point_free, }, 0, 0, RUBY_TYPED_FREE_IMMEDIATELY, }; -static VALUE ossl_ec_point_alloc(VALUE klass) +static VALUE +ossl_ec_point_alloc(VALUE klass) { - ossl_ec_point *ec_point; + return TypedData_Wrap_Struct(klass, &ossl_ec_point_type, NULL); +} + +static VALUE +ec_point_new(const EC_POINT *point, const EC_GROUP *group) +{ + EC_POINT *point_new; VALUE obj; - obj = TypedData_Make_Struct(klass, ossl_ec_point, &ossl_ec_point_type, ec_point); + obj = ossl_ec_point_alloc(cEC_POINT); + point_new = EC_POINT_dup(point, group); + if (!point_new) + ossl_raise(eEC_POINT, "EC_POINT_dup"); + RTYPEDDATA_DATA(obj) = point_new; + rb_ivar_set(obj, id_i_group, ec_group_new(group)); return obj; } @@ -1418,14 +1335,13 @@ static VALUE ossl_ec_point_alloc(VALUE klass) */ static VALUE ossl_ec_point_initialize(int argc, VALUE *argv, VALUE self) { - ossl_ec_point *ec_point; - EC_POINT *point = NULL; + EC_POINT *point; VALUE arg1, arg2; VALUE group_v = Qnil; const EC_GROUP *group = NULL; - TypedData_Get_Struct(self, ossl_ec_point, &ossl_ec_point_type, ec_point); - if (ec_point->point) + TypedData_Get_Struct(self, EC_POINT, &ossl_ec_point_type, point); + if (point) ossl_raise(eEC_POINT, "EC_POINT already initialized"); switch (rb_scan_args(argc, argv, "11", &arg1, &arg2)) { @@ -1434,13 +1350,13 @@ static VALUE ossl_ec_point_initialize(int argc, VALUE *argv, VALUE self) const EC_POINT *arg_point; group_v = rb_attr_get(arg1, id_i_group); - SafeRequire_EC_GROUP(group_v, group); - SafeRequire_EC_POINT(arg1, arg_point); + SafeGetECGroup(group_v, group); + SafeGetECPoint(arg1, arg_point); point = EC_POINT_dup(arg_point, group); } else if (rb_obj_is_kind_of(arg1, cEC_GROUP)) { group_v = arg1; - SafeRequire_EC_GROUP(group_v, group); + SafeGetECGroup(group_v, group); point = EC_POINT_new(group); } else { @@ -1452,7 +1368,7 @@ static VALUE ossl_ec_point_initialize(int argc, VALUE *argv, VALUE self) if (!rb_obj_is_kind_of(arg1, cEC_GROUP)) ossl_raise(rb_eArgError, "1st argument must be OpenSSL::PKey::EC::Group"); group_v = arg1; - SafeRequire_EC_GROUP(group_v, group); + SafeGetECGroup(group_v, group); if (rb_obj_is_kind_of(arg2, cBN)) { const BIGNUM *bn = GetBNPtr(arg2); @@ -1480,8 +1396,7 @@ static VALUE ossl_ec_point_initialize(int argc, VALUE *argv, VALUE self) if (NIL_P(group_v)) ossl_raise(rb_eRuntimeError, "missing group (internal error)"); - ec_point->point = point; - + RTYPEDDATA_DATA(self) = point; rb_ivar_set(self, id_i_group, group_v); return self; @@ -1490,23 +1405,22 @@ static VALUE ossl_ec_point_initialize(int argc, VALUE *argv, VALUE self) static VALUE ossl_ec_point_initialize_copy(VALUE self, VALUE other) { - ossl_ec_point *ec_point; - EC_POINT *orig; + EC_POINT *point, *point_new; EC_GROUP *group; VALUE group_v; - TypedData_Get_Struct(self, ossl_ec_point, &ossl_ec_point_type, ec_point); - if (ec_point->point) + TypedData_Get_Struct(self, EC_POINT, &ossl_ec_point_type, point_new); + if (point_new) ossl_raise(eEC_POINT, "EC::Point already initialized"); - SafeRequire_EC_POINT(other, orig); + SafeGetECPoint(other, point); group_v = rb_obj_dup(rb_attr_get(other, id_i_group)); - SafeRequire_EC_GROUP(group_v, group); + SafeGetECGroup(group_v, group); - ec_point->point = EC_POINT_dup(orig, group); - if (!ec_point->point) + point_new = EC_POINT_dup(point, group); + if (!point_new) ossl_raise(eEC_POINT, "EC_POINT_dup"); - rb_ivar_set(self, id_i_key, Qnil); + RTYPEDDATA_DATA(self) = point_new; rb_ivar_set(self, id_i_group, group_v); return self; @@ -1527,9 +1441,9 @@ static VALUE ossl_ec_point_eql(VALUE a, VALUE b) if (ossl_ec_group_eql(group_v1, group_v2) == Qfalse) return Qfalse; - Require_EC_POINT(a, point1); - SafeRequire_EC_POINT(b, point2); - SafeRequire_EC_GROUP(group_v1, group); + GetECPoint(a, point1); + SafeGetECPoint(b, point2); + SafeGetECGroup(group_v1, group); if (EC_POINT_cmp(group, point1, point2, ossl_bn_ctx) == 1) return Qfalse; @@ -1544,11 +1458,10 @@ static VALUE ossl_ec_point_eql(VALUE a, VALUE b) static VALUE ossl_ec_point_is_at_infinity(VALUE self) { EC_POINT *point; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); switch (EC_POINT_is_at_infinity(group, point)) { case 1: return Qtrue; @@ -1566,11 +1479,10 @@ static VALUE ossl_ec_point_is_at_infinity(VALUE self) static VALUE ossl_ec_point_is_on_curve(VALUE self) { EC_POINT *point; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); switch (EC_POINT_is_on_curve(group, point, ossl_bn_ctx)) { case 1: return Qtrue; @@ -1588,11 +1500,10 @@ static VALUE ossl_ec_point_is_on_curve(VALUE self) static VALUE ossl_ec_point_make_affine(VALUE self) { EC_POINT *point; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); if (EC_POINT_make_affine(group, point, ossl_bn_ctx) != 1) ossl_raise(cEC_POINT, "EC_POINT_make_affine"); @@ -1607,11 +1518,10 @@ static VALUE ossl_ec_point_make_affine(VALUE self) static VALUE ossl_ec_point_invert(VALUE self) { EC_POINT *point; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); if (EC_POINT_invert(group, point, ossl_bn_ctx) != 1) ossl_raise(cEC_POINT, "EC_POINT_invert"); @@ -1626,11 +1536,10 @@ static VALUE ossl_ec_point_invert(VALUE self) static VALUE ossl_ec_point_set_to_infinity(VALUE self) { EC_POINT *point; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); if (EC_POINT_set_to_infinity(group, point) != 1) ossl_raise(cEC_POINT, "EC_POINT_set_to_infinity"); @@ -1648,13 +1557,12 @@ static VALUE ossl_ec_point_to_bn(VALUE self) { EC_POINT *point; VALUE bn_obj; - VALUE group_v = rb_attr_get(self, id_i_group); const EC_GROUP *group; point_conversion_form_t form; BIGNUM *bn; - Require_EC_POINT(self, point); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point); + GetECPointGroup(self, group); form = EC_GROUP_get_point_conversion_form(group); @@ -1692,12 +1600,12 @@ static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self) VALUE arg1, arg2, arg3, result; const BIGNUM *bn_g = NULL; - Require_EC_POINT(self, point_self); - SafeRequire_EC_GROUP(group_v, group); + GetECPoint(self, point_self); + SafeGetECGroup(group_v, group); result = rb_obj_alloc(cEC_POINT); ossl_ec_point_initialize(1, &group_v, result); - Require_EC_POINT(result, point_result); + GetECPoint(result, point_result); rb_scan_args(argc, argv, "12", &arg1, &arg2, &arg3); if (!RB_TYPE_P(arg1, T_ARRAY)) { @@ -1730,7 +1638,7 @@ static VALUE ossl_ec_point_mul(int argc, VALUE *argv, VALUE self) points = ALLOCV_N(const EC_POINT *, tmp_p, num); points[0] = point_self; /* self */ for (i = 0; i < num - 1; i++) - SafeRequire_EC_POINT(RARRAY_AREF(arg2, i), points[i + 1]); + SafeGetECPoint(RARRAY_AREF(arg2, i), points[i + 1]); if (!NIL_P(arg3)) bn_g = GetBNPtr(arg3); @@ -1889,7 +1797,6 @@ void Init_ossl_ec(void) rb_define_method(cEC_POINT, "mul", ossl_ec_point_mul, -1); id_i_group = rb_intern("@group"); - id_i_key = rb_intern("@key"); } #else /* defined NO_EC */ diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index 9a499a7a0d..053613adae 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -36,50 +36,19 @@ VALUE cSSLSocket; static VALUE eSSLErrorWaitReadable; static VALUE eSSLErrorWaitWritable; -#define ossl_sslctx_set_cert(o,v) rb_iv_set((o),"@cert",(v)) -#define ossl_sslctx_set_key(o,v) rb_iv_set((o),"@key",(v)) -#define ossl_sslctx_set_client_ca(o,v) rb_iv_set((o),"@client_ca",(v)) -#define ossl_sslctx_set_ca_file(o,v) rb_iv_set((o),"@ca_file",(v)) -#define ossl_sslctx_set_ca_path(o,v) rb_iv_set((o),"@ca_path",(v)) -#define ossl_sslctx_set_timeout(o,v) rb_iv_set((o),"@timeout",(v)) -#define ossl_sslctx_set_verify_mode(o,v) rb_iv_set((o),"@verify_mode",(v)) -#define ossl_sslctx_set_verify_dep(o,v) rb_iv_set((o),"@verify_depth",(v)) -#define ossl_sslctx_set_verify_cb(o,v) rb_iv_set((o),"@verify_callback",(v)) -#define ossl_sslctx_set_cert_store(o,v) rb_iv_set((o),"@cert_store",(v)) -#define ossl_sslctx_set_extra_cert(o,v) rb_iv_set((o),"@extra_chain_cert",(v)) -#define ossl_sslctx_set_client_cert_cb(o,v) rb_iv_set((o),"@client_cert_cb",(v)) -#define ossl_sslctx_set_sess_id_ctx(o, v) rb_iv_set((o),"@session_id_context",(v)) - -#define ossl_sslctx_get_cert(o) rb_iv_get((o),"@cert") -#define ossl_sslctx_get_key(o) rb_iv_get((o),"@key") -#define ossl_sslctx_get_client_ca(o) rb_iv_get((o),"@client_ca") -#define ossl_sslctx_get_ca_file(o) rb_iv_get((o),"@ca_file") -#define ossl_sslctx_get_ca_path(o) rb_iv_get((o),"@ca_path") -#define ossl_sslctx_get_timeout(o) rb_iv_get((o),"@timeout") -#define ossl_sslctx_get_verify_mode(o) rb_iv_get((o),"@verify_mode") -#define ossl_sslctx_get_verify_dep(o) rb_iv_get((o),"@verify_depth") -#define ossl_sslctx_get_verify_cb(o) rb_iv_get((o),"@verify_callback") -#define ossl_sslctx_get_cert_store(o) rb_iv_get((o),"@cert_store") -#define ossl_sslctx_get_extra_cert(o) rb_iv_get((o),"@extra_chain_cert") -#define ossl_sslctx_get_client_cert_cb(o) rb_iv_get((o),"@client_cert_cb") -#define ossl_sslctx_get_tmp_ecdh_cb(o) rb_iv_get((o),"@tmp_ecdh_callback") -#define ossl_sslctx_get_sess_id_ctx(o) rb_iv_get((o),"@session_id_context") -#define ossl_sslctx_get_verify_hostname(o) rb_iv_get((o),"@verify_hostname") - -#define ossl_ssl_get_io(o) rb_iv_get((o),"@io") -#define ossl_ssl_get_ctx(o) rb_iv_get((o),"@context") - -#define ossl_ssl_set_io(o,v) rb_iv_set((o),"@io",(v)) -#define ossl_ssl_set_ctx(o,v) rb_iv_set((o),"@context",(v)) -#define ossl_ssl_set_sync_close(o,v) rb_iv_set((o),"@sync_close",(v)) -#define ossl_ssl_set_hostname_v(o,v) rb_iv_set((o),"@hostname",(v)) -#define ossl_ssl_set_tmp_dh(o,v) rb_iv_set((o),"@tmp_dh",(v)) -#define ossl_ssl_set_tmp_ecdh(o,v) rb_iv_set((o),"@tmp_ecdh",(v)) - static ID ID_callback_state; - static VALUE sym_exception, sym_wait_readable, sym_wait_writable; +static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, + id_i_verify_depth, id_i_verify_callback, id_i_client_ca, + id_i_renegotiation_cb, id_i_cert, id_i_key, id_i_extra_chain_cert, + id_i_client_cert_cb, id_i_tmp_ecdh_callback, id_i_timeout, + id_i_session_id_context, id_i_session_get_cb, id_i_session_new_cb, + id_i_session_remove_cb, id_i_npn_select_cb, id_i_npn_protocols, + id_i_alpn_select_cb, id_i_alpn_protocols, id_i_servername_cb, + id_i_verify_hostname; +static ID id_i_io, id_i_context, id_i_hostname; + /* * SSLContext class */ @@ -223,9 +192,10 @@ ossl_sslctx_set_ssl_version(VALUE self, VALUE ssl_method) static VALUE ossl_call_client_cert_cb(VALUE obj) { - VALUE cb, ary, cert, key; + VALUE ctx_obj, cb, ary, cert, key; - cb = ossl_sslctx_get_client_cert_cb(ossl_ssl_get_ctx(obj)); + ctx_obj = rb_attr_get(obj, id_i_context); + cb = rb_attr_get(ctx_obj, id_i_client_cert_cb); if (NIL_P(cb)) return Qnil; @@ -281,7 +251,6 @@ ossl_tmp_dh_callback(SSL *ssl, int is_export, int keylength) dh = rb_protect(ossl_call_tmp_dh_callback, args, NULL); if (!RTEST(dh)) return NULL; - ossl_ssl_set_tmp_dh(rb_ssl, dh); return EVP_PKEY_get0_DH(GetPKeyPtr(dh)); } @@ -315,7 +284,6 @@ ossl_tmp_ecdh_callback(SSL *ssl, int is_export, int keylength) ecdh = rb_protect(ossl_call_tmp_ecdh_callback, args, NULL); if (!RTEST(ecdh)) return NULL; - ossl_ssl_set_tmp_ecdh(rb_ssl, ecdh); return EVP_PKEY_get0_EC_KEY(GetPKeyPtr(ecdh)); } @@ -330,7 +298,7 @@ call_verify_certificate_identity(VALUE ctx_v) ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - hostname = rb_attr_get(ssl_obj, rb_intern("@hostname")); + hostname = rb_attr_get(ssl_obj, id_i_hostname); if (!RTEST(hostname)) { rb_warning("verify_hostname requires hostname to be set"); @@ -345,14 +313,15 @@ call_verify_certificate_identity(VALUE ctx_v) static int ossl_ssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { - VALUE cb, ssl_obj, verify_hostname, ret; + VALUE cb, ssl_obj, sslctx_obj, verify_hostname, ret; SSL *ssl; int status; ssl = X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); cb = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_vcb_idx); ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); - verify_hostname = ossl_sslctx_get_verify_hostname(ossl_ssl_get_ctx(ssl_obj)); + sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + verify_hostname = rb_attr_get(sslctx_obj, id_i_verify_hostname); if (preverify_ok && RTEST(verify_hostname) && !SSL_is_server(ssl) && !X509_STORE_CTX_get_error_depth(ctx)) { @@ -474,7 +443,7 @@ ossl_call_session_remove_cb(VALUE ary) Check_Type(ary, T_ARRAY); sslctx_obj = rb_ary_entry(ary, 0); - cb = rb_iv_get(sslctx_obj, "@session_remove_cb"); + cb = rb_attr_get(sslctx_obj, id_i_session_remove_cb); if (NIL_P(cb)) return Qnil; return rb_funcall(cb, rb_intern("call"), 1, ary); @@ -536,9 +505,8 @@ ossl_call_servername_cb(VALUE ary) Check_Type(ary, T_ARRAY); ssl_obj = rb_ary_entry(ary, 0); - sslctx_obj = rb_iv_get(ssl_obj, "@context"); - if (NIL_P(sslctx_obj)) return Qnil; - cb = rb_iv_get(sslctx_obj, "@servername_cb"); + sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + cb = rb_attr_get(sslctx_obj, id_i_servername_cb); if (NIL_P(cb)) return Qnil; ret_obj = rb_funcall(cb, rb_intern("call"), 1, ary); @@ -550,9 +518,10 @@ ossl_call_servername_cb(VALUE ary) GetSSL(ssl_obj, ssl); GetSSLCTX(ret_obj, ctx2); SSL_set_SSL_CTX(ssl, ctx2); - rb_iv_set(ssl_obj, "@context", ret_obj); + rb_ivar_set(ssl_obj, id_i_context, ret_obj); } else if (!NIL_P(ret_obj)) { - ossl_raise(rb_eArgError, "servername_cb must return an OpenSSL::SSL::SSLContext object or nil"); + ossl_raise(rb_eArgError, "servername_cb must return an " + "OpenSSL::SSL::SSLContext object or nil"); } return ret_obj; @@ -596,15 +565,15 @@ ssl_renegotiation_cb(const SSL *ssl) ossl_raise(eSSLError, "SSL object could not be retrieved"); ssl_obj = (VALUE)ptr; - sslctx_obj = rb_iv_get(ssl_obj, "@context"); - if (NIL_P(sslctx_obj)) return; - cb = rb_iv_get(sslctx_obj, "@renegotiation_cb"); + sslctx_obj = rb_attr_get(ssl_obj, id_i_context); + cb = rb_attr_get(sslctx_obj, id_i_renegotiation_cb); if (NIL_P(cb)) return; (void) rb_funcall(cb, rb_intern("call"), 1, ssl_obj); } -#if defined(HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB) || defined(HAVE_SSL_CTX_SET_ALPN_SELECT_CB) +#if defined(HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB) || \ + defined(HAVE_SSL_CTX_SET_ALPN_SELECT_CB) static VALUE ssl_npn_encode_protocol_i(VALUE cur, VALUE encoded) { @@ -627,14 +596,20 @@ ssl_encode_npn_protocols(VALUE protocols) return encoded; } -static int -ssl_npn_select_cb_common(VALUE cb, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen) +struct npn_select_cb_common_args { + VALUE cb; + const unsigned char *in; + unsigned inlen; +}; + +static VALUE +npn_select_cb_common_i(VALUE tmp) { - VALUE selected; - long len; - VALUE protocols = rb_ary_new(); + struct npn_select_cb_common_args *args = (void *)tmp; + const unsigned char *in = args->in, *in_end = in + args->inlen; unsigned char l; - const unsigned char *in_end = in + inlen; + long len; + VALUE selected, protocols = rb_ary_new(); /* assume OpenSSL verifies this format */ /* The format is len_1|proto_1|...|len_n|proto_n */ @@ -644,21 +619,44 @@ ssl_npn_select_cb_common(VALUE cb, const unsigned char **out, unsigned char *out in += l; } - selected = rb_funcall(cb, rb_intern("call"), 1, protocols); + selected = rb_funcall(args->cb, rb_intern("call"), 1, protocols); StringValue(selected); len = RSTRING_LEN(selected); if (len < 1 || len >= 256) { ossl_raise(eSSLError, "Selected protocol name must have length 1..255"); } + + return selected; +} + +static int +ssl_npn_select_cb_common(SSL *ssl, VALUE cb, const unsigned char **out, + unsigned char *outlen, const unsigned char *in, + unsigned int inlen) +{ + VALUE selected; + int status; + struct npn_select_cb_common_args args = { cb, in, inlen }; + + selected = rb_protect(npn_select_cb_common_i, (VALUE)&args, &status); + if (status) { + VALUE ssl_obj = (VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx); + + rb_ivar_set(ssl_obj, ID_callback_state, INT2NUM(status)); + return SSL_TLSEXT_ERR_ALERT_FATAL; + } + *out = (unsigned char *)RSTRING_PTR(selected); - *outlen = (unsigned char)len; + *outlen = (unsigned char)RSTRING_LEN(selected); return SSL_TLSEXT_ERR_OK; } +#endif #ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB static int -ssl_npn_advertise_cb(SSL *ssl, const unsigned char **out, unsigned int *outlen, void *arg) +ssl_npn_advertise_cb(SSL *ssl, const unsigned char **out, unsigned int *outlen, + void *arg) { VALUE protocols = (VALUE)arg; @@ -669,30 +667,32 @@ ssl_npn_advertise_cb(SSL *ssl, const unsigned char **out, unsigned int *outlen, } static int -ssl_npn_select_cb(SSL *s, unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) +ssl_npn_select_cb(SSL *ssl, unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, void *arg) { VALUE sslctx_obj, cb; sslctx_obj = (VALUE) arg; - cb = rb_iv_get(sslctx_obj, "@npn_select_cb"); + cb = rb_attr_get(sslctx_obj, id_i_npn_select_cb); - return ssl_npn_select_cb_common(cb, (const unsigned char **)out, outlen, in, inlen); + return ssl_npn_select_cb_common(ssl, cb, (const unsigned char **)out, + outlen, in, inlen); } #endif #ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB static int -ssl_alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) +ssl_alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, void *arg) { VALUE sslctx_obj, cb; sslctx_obj = (VALUE) arg; - cb = rb_iv_get(sslctx_obj, "@alpn_select_cb"); + cb = rb_attr_get(sslctx_obj, id_i_alpn_select_cb); - return ssl_npn_select_cb_common(cb, out, outlen, in, inlen); + return ssl_npn_select_cb_common(ssl, cb, out, outlen, in, inlen); } #endif -#endif /* HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB || HAVE_SSL_CTX_SET_ALPN_SELECT_CB */ /* This function may serve as the entry point to support further callbacks. */ static void @@ -768,7 +768,7 @@ ossl_sslctx_setup(VALUE self) #if !defined(OPENSSL_NO_EC) /* We added SSLContext#tmp_ecdh_callback= in Ruby 2.3.0, * but SSL_CTX_set_tmp_ecdh_callback() was removed in OpenSSL 1.1.0. */ - if (RTEST(ossl_sslctx_get_tmp_ecdh_cb(self))) { + if (RTEST(rb_attr_get(self, id_i_tmp_ecdh_callback))) { # if defined(HAVE_SSL_CTX_SET_TMP_ECDH_CALLBACK) rb_warn("#tmp_ecdh_callback= is deprecated; use #ecdh_curves= instead"); SSL_CTX_set_tmp_ecdh_callback(ctx, ossl_tmp_ecdh_callback); @@ -785,7 +785,7 @@ ossl_sslctx_setup(VALUE self) } #endif /* OPENSSL_NO_EC */ - val = ossl_sslctx_get_cert_store(self); + val = rb_attr_get(self, id_i_cert_store); if (!NIL_P(val)) { X509_STORE *store = GetX509StorePtr(val); /* NO NEED TO DUP */ SSL_CTX_set_cert_store(ctx, store); @@ -802,15 +802,15 @@ ossl_sslctx_setup(VALUE self) #endif } - val = ossl_sslctx_get_extra_cert(self); + val = rb_attr_get(self, id_i_extra_chain_cert); if(!NIL_P(val)){ rb_block_call(val, rb_intern("each"), 0, 0, ossl_sslctx_add_extra_chain_cert_i, self); } /* private key may be bundled in certificate file. */ - val = ossl_sslctx_get_cert(self); + val = rb_attr_get(self, id_i_cert); cert = NIL_P(val) ? NULL : GetX509CertPtr(val); /* NO DUP NEEDED */ - val = ossl_sslctx_get_key(self); + val = rb_attr_get(self, id_i_key); key = NIL_P(val) ? NULL : GetPrivPKeyPtr(val); /* NO DUP NEEDED */ if (cert && key) { if (!SSL_CTX_use_certificate(ctx, cert)) { @@ -826,7 +826,7 @@ ossl_sslctx_setup(VALUE self) } } - val = ossl_sslctx_get_client_ca(self); + val = rb_attr_get(self, id_i_client_ca); if(!NIL_P(val)){ if (RB_TYPE_P(val, T_ARRAY)) { for(i = 0; i < RARRAY_LEN(val); i++){ @@ -846,48 +846,52 @@ ossl_sslctx_setup(VALUE self) } } - val = ossl_sslctx_get_ca_file(self); + val = rb_attr_get(self, id_i_ca_file); ca_file = NIL_P(val) ? NULL : StringValueCStr(val); - val = ossl_sslctx_get_ca_path(self); + val = rb_attr_get(self, id_i_ca_path); ca_path = NIL_P(val) ? NULL : StringValueCStr(val); if(ca_file || ca_path){ if (!SSL_CTX_load_verify_locations(ctx, ca_file, ca_path)) rb_warning("can't set verify locations"); } - val = ossl_sslctx_get_verify_mode(self); + val = rb_attr_get(self, id_i_verify_mode); verify_mode = NIL_P(val) ? SSL_VERIFY_NONE : NUM2INT(val); SSL_CTX_set_verify(ctx, verify_mode, ossl_ssl_verify_callback); - if (RTEST(ossl_sslctx_get_client_cert_cb(self))) + if (RTEST(rb_attr_get(self, id_i_client_cert_cb))) SSL_CTX_set_client_cert_cb(ctx, ossl_client_cert_cb); - val = ossl_sslctx_get_timeout(self); + val = rb_attr_get(self, id_i_timeout); if(!NIL_P(val)) SSL_CTX_set_timeout(ctx, NUM2LONG(val)); - val = ossl_sslctx_get_verify_dep(self); + val = rb_attr_get(self, id_i_verify_depth); if(!NIL_P(val)) SSL_CTX_set_verify_depth(ctx, NUM2INT(val)); #ifdef HAVE_SSL_CTX_SET_NEXT_PROTO_SELECT_CB - val = rb_iv_get(self, "@npn_protocols"); + val = rb_attr_get(self, id_i_npn_protocols); if (!NIL_P(val)) { VALUE encoded = ssl_encode_npn_protocols(val); SSL_CTX_set_next_protos_advertised_cb(ctx, ssl_npn_advertise_cb, (void *)encoded); OSSL_Debug("SSL NPN advertise callback added"); } - if (RTEST(rb_iv_get(self, "@npn_select_cb"))) { + if (RTEST(rb_attr_get(self, id_i_npn_select_cb))) { SSL_CTX_set_next_proto_select_cb(ctx, ssl_npn_select_cb, (void *) self); OSSL_Debug("SSL NPN select callback added"); } #endif #ifdef HAVE_SSL_CTX_SET_ALPN_SELECT_CB - val = rb_iv_get(self, "@alpn_protocols"); + val = rb_attr_get(self, id_i_alpn_protocols); if (!NIL_P(val)) { VALUE rprotos = ssl_encode_npn_protocols(val); - SSL_CTX_set_alpn_protos(ctx, (unsigned char *)RSTRING_PTR(rprotos), RSTRING_LENINT(rprotos)); + + /* returns 0 on success */ + if (SSL_CTX_set_alpn_protos(ctx, (unsigned char *)RSTRING_PTR(rprotos), + RSTRING_LENINT(rprotos))) + ossl_raise(eSSLError, "SSL_CTX_set_alpn_protos"); OSSL_Debug("SSL ALPN values added"); } - if (RTEST(rb_iv_get(self, "@alpn_select_cb"))) { + if (RTEST(rb_attr_get(self, id_i_alpn_select_cb))) { SSL_CTX_set_alpn_select_cb(ctx, ssl_alpn_select_cb, (void *) self); OSSL_Debug("SSL ALPN select callback added"); } @@ -895,7 +899,7 @@ ossl_sslctx_setup(VALUE self) rb_obj_freeze(self); - val = ossl_sslctx_get_sess_id_ctx(self); + val = rb_attr_get(self, id_i_session_id_context); if (!NIL_P(val)){ StringValue(val); if (!SSL_CTX_set_session_id_context(ctx, (unsigned char *)RSTRING_PTR(val), @@ -904,21 +908,21 @@ ossl_sslctx_setup(VALUE self) } } - if (RTEST(rb_iv_get(self, "@session_get_cb"))) { + if (RTEST(rb_attr_get(self, id_i_session_get_cb))) { SSL_CTX_sess_set_get_cb(ctx, ossl_sslctx_session_get_cb); OSSL_Debug("SSL SESSION get callback added"); } - if (RTEST(rb_iv_get(self, "@session_new_cb"))) { + if (RTEST(rb_attr_get(self, id_i_session_new_cb))) { SSL_CTX_sess_set_new_cb(ctx, ossl_sslctx_session_new_cb); OSSL_Debug("SSL SESSION new callback added"); } - if (RTEST(rb_iv_get(self, "@session_remove_cb"))) { + if (RTEST(rb_attr_get(self, id_i_session_remove_cb))) { SSL_CTX_sess_set_remove_cb(ctx, ossl_sslctx_session_remove_cb); OSSL_Debug("SSL SESSION remove callback added"); } #ifdef HAVE_SSL_SET_TLSEXT_HOST_NAME - val = rb_iv_get(self, "@servername_cb"); + val = rb_attr_get(self, id_i_servername_cb); if (!NIL_P(val)) { SSL_CTX_set_tlsext_servername_callback(ctx, ssl_servername_cb); OSSL_Debug("SSL TLSEXT servername callback added"); @@ -1437,14 +1441,12 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) v_ctx = rb_funcall(cSSLContext, rb_intern("new"), 0); GetSSLCTX(v_ctx, ctx); - ossl_ssl_set_ctx(self, v_ctx); + rb_ivar_set(self, id_i_context, v_ctx); ossl_sslctx_setup(v_ctx); if (rb_respond_to(io, rb_intern("nonblock="))) rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - ossl_ssl_set_io(self, io); - - ossl_ssl_set_sync_close(self, Qfalse); + rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); if (!ssl) @@ -1453,7 +1455,7 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) SSL_set_ex_data(ssl, ossl_ssl_ex_ptr_idx, (void *)self); SSL_set_info_callback(ssl, ssl_info_cb); - verify_cb = ossl_sslctx_get_verify_cb(v_ctx); + verify_cb = rb_attr_get(v_ctx, id_i_verify_callback); SSL_set_ex_data(ssl, ossl_ssl_ex_vcb_idx, (void *)verify_cb); rb_call_super(0, NULL); @@ -1472,7 +1474,7 @@ ossl_ssl_setup(VALUE self) if (ssl_started(ssl)) return Qtrue; - io = ossl_ssl_get_io(self); + io = rb_attr_get(self, id_i_io); GetOpenFile(io, fptr); rb_io_check_readable(fptr); rb_io_check_writable(fptr); @@ -1527,11 +1529,11 @@ ossl_start_ssl(VALUE self, int (*func)(), const char *funcname, VALUE opts) GetSSL(self, ssl); - GetOpenFile(ossl_ssl_get_io(self), fptr); + GetOpenFile(rb_attr_get(self, id_i_io), fptr); for(;;){ ret = func(ssl); - cb_state = rb_ivar_get(self, ID_callback_state); + cb_state = rb_attr_get(self, ID_callback_state); if (!NIL_P(cb_state)) { /* must cleanup OpenSSL error stack before re-raising */ ossl_clear_error(); @@ -1666,7 +1668,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) int ilen, nread = 0; VALUE len, str; rb_io_t *fptr; - VALUE opts = Qnil; + VALUE io, opts = Qnil; if (nonblock) { rb_scan_args(argc, argv, "11:", &len, &str, &opts); @@ -1684,7 +1686,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) if(ilen == 0) return str; GetSSL(self, ssl); - GetOpenFile(ossl_ssl_get_io(self), fptr); + io = rb_attr_get(self, id_i_io); + GetOpenFile(io, fptr); if (ssl_started(ssl)) { if(!nonblock && SSL_pending(ssl) <= 0) rb_thread_wait_fd(FPTR_TO_FD(fptr)); @@ -1718,13 +1721,13 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) } } else { - ID meth = nonblock ? rb_intern("read_nonblock") : rb_intern("sysread"); - rb_warning("SSL session is not started yet."); - if (nonblock) { - return rb_funcall(ossl_ssl_get_io(self), meth, 3, len, str, opts); - } else { - return rb_funcall(ossl_ssl_get_io(self), meth, 2, len, str); - } + ID meth = nonblock ? rb_intern("read_nonblock") : rb_intern("sysread"); + + rb_warning("SSL session is not started yet."); + if (nonblock) + return rb_funcall(io, meth, 3, len, str, opts); + else + return rb_funcall(io, meth, 2, len, str); } end: @@ -1774,11 +1777,12 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) int nwrite = 0; rb_io_t *fptr; int nonblock = opts != Qfalse; + VALUE io; StringValue(str); GetSSL(self, ssl); - GetOpenFile(ossl_ssl_get_io(self), fptr); - + io = rb_attr_get(self, id_i_io); + GetOpenFile(io, fptr); if (ssl_started(ssl)) { for (;;){ int num = RSTRING_LENINT(str); @@ -1809,9 +1813,14 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) } } else { - ID id_syswrite = rb_intern("syswrite"); - rb_warning("SSL session is not started yet."); - return rb_funcall(ossl_ssl_get_io(self), id_syswrite, 1, str); + ID meth = nonblock ? + rb_intern("write_nonblock") : rb_intern("syswrite"); + + rb_warning("SSL session is not started yet."); + if (nonblock) + return rb_funcall(io, meth, 2, str, opts); + else + return rb_funcall(io, meth, 1, str); } end: @@ -2082,7 +2091,7 @@ ossl_ssl_set_hostname(VALUE self, VALUE arg) ossl_raise(eSSLError, NULL); /* for SSLSocket#hostname */ - ossl_ssl_set_hostname_v(self, arg); + rb_ivar_set(self, id_i_hostname, arg); return arg; } @@ -2201,6 +2210,8 @@ ossl_ssl_tmp_key(VALUE self) # endif /* defined(HAVE_SSL_GET_SERVER_TMP_KEY) */ #endif /* !defined(OPENSSL_NO_SOCK) */ +#undef rb_intern +#define rb_intern(s) rb_intern_const(s) void Init_ossl_ssl(void) { @@ -2214,7 +2225,7 @@ Init_ossl_ssl(void) rb_mWaitWritable = rb_define_module_under(rb_cIO, "WaitWritable"); #endif - ID_callback_state = rb_intern("@callback_state"); + ID_callback_state = rb_intern("callback_state"); ossl_ssl_ex_vcb_idx = SSL_get_ex_new_index(0,(void *)"ossl_ssl_ex_vcb_idx",0,0,0); ossl_ssl_ex_store_p = SSL_get_ex_new_index(0,(void *)"ossl_ssl_ex_store_p",0,0,0); @@ -2672,8 +2683,39 @@ Init_ossl_ssl(void) ossl_ssl_def_const(OP_NETSCAPE_CA_DN_BUG); ossl_ssl_def_const(OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG); -#undef rb_intern sym_exception = ID2SYM(rb_intern("exception")); sym_wait_readable = ID2SYM(rb_intern("wait_readable")); sym_wait_writable = ID2SYM(rb_intern("wait_writable")); + +#define DefIVarID(name) do \ + id_i_##name = rb_intern("@"#name); while (0) + + DefIVarID(cert_store); + DefIVarID(ca_file); + DefIVarID(ca_path); + DefIVarID(verify_mode); + DefIVarID(verify_depth); + DefIVarID(verify_callback); + DefIVarID(client_ca); + DefIVarID(renegotiation_cb); + DefIVarID(cert); + DefIVarID(key); + DefIVarID(extra_chain_cert); + DefIVarID(client_cert_cb); + DefIVarID(tmp_ecdh_callback); + DefIVarID(timeout); + DefIVarID(session_id_context); + DefIVarID(session_get_cb); + DefIVarID(session_new_cb); + DefIVarID(session_remove_cb); + DefIVarID(npn_select_cb); + DefIVarID(npn_protocols); + DefIVarID(alpn_protocols); + DefIVarID(alpn_select_cb); + DefIVarID(servername_cb); + DefIVarID(verify_hostname); + + DefIVarID(io); + DefIVarID(context); + DefIVarID(hostname); } diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 9250222979..7d962c385a 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -113,11 +113,25 @@ module OpenSSL::TestPairM } end + def test_gets + ssl_pair {|s1, s2| + s1 << "abc\n\n$def123ghi" + s1.close + ret = s2.gets + assert_equal Encoding::BINARY, ret.encoding + assert_equal "abc\n", ret + assert_equal "\n$", s2.gets("$") + assert_equal "def123", s2.gets(/\d+/) + assert_equal "ghi", s2.gets + assert_equal nil, s2.gets + } + end + def test_gets_eof_limit ssl_pair {|s1, s2| s1.write("hello") s1.close # trigger EOF - assert_match "hello", s2.gets("\n", 6), "[ruby-core:70149] [Bug #11140]" + assert_match "hello", s2.gets("\n", 6), "[ruby-core:70149] [Bug #11400]" } end @@ -344,147 +358,6 @@ module OpenSSL::TestPairM serv.close if serv && !serv.closed? end - def test_connect_works_when_setting_dh_callback_to_nil - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "DH" - ctx2.security_level = 0 - ctx2.tmp_dh_callback = nil - sock1, sock2 = tcp_pair - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - s2.accept_nonblock(exception: false) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "DH" - ctx1.security_level = 0 - ctx1.tmp_dh_callback = nil - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - t = Thread.new { s1.connect } - - EnvUtil.suppress_warning { # uses default callback - assert_nothing_raised { s2.accept } - } - assert_equal s1, t.value - ensure - t.join if t - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - end - - def test_connect_without_setting_dh_callback - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "DH" - ctx2.security_level = 0 - sock1, sock2 = tcp_pair - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - s2.accept_nonblock(exception: false) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "DH" - ctx1.security_level = 0 - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - t = Thread.new { s1.connect } - - EnvUtil.suppress_warning { # default DH - assert_nothing_raised { s2.accept } - } - assert_equal s1, t.value - ensure - t.join if t - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - end - - def test_ecdh_callback - return unless OpenSSL::SSL::SSLContext.instance_methods.include?(:tmp_ecdh_callback) - EnvUtil.suppress_warning do # tmp_ecdh_callback is deprecated (2016-05) - begin - called = false - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "ECDH" - # OpenSSL 1.1.0 doesn't have tmp_ecdh_callback so this shouldn't be required - ctx2.security_level = 0 - ctx2.tmp_ecdh_callback = ->(*args) { - called = true - OpenSSL::PKey::EC.new "prime256v1" - } - - sock1, sock2 = tcp_pair - - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "ECDH" - ctx1.security_level = 0 - - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - th = Thread.new do - begin - rv = s1.connect_nonblock(exception: false) - case rv - when :wait_writable - IO.select(nil, [s1], nil, 5) - when :wait_readable - IO.select([s1], nil, nil, 5) - end - end until rv == s1 - end - - s2.accept - assert called, 'ecdh callback should be called' - rescue OpenSSL::SSL::SSLError => e - if e.message =~ /no cipher match/ - pend "ECDH cipher not supported." - else - raise e - end - ensure - th.join if th - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - end - end - end - - def test_ecdh_curves - sock1, sock2 = tcp_pair - - ctx1 = OpenSSL::SSL::SSLContext.new - begin - ctx1.ciphers = "ECDH" - rescue OpenSSL::SSL::SSLError - pend "ECDH is not enabled in this OpenSSL" if $!.message =~ /no cipher match/ - raise - end - ctx1.ecdh_curves = "P-384:P-521" - ctx1.security_level = 0 - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "ECDH" - ctx2.ecdh_curves = "P-256:P-384" - ctx2.security_level = 0 - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - - th = Thread.new { s1.accept } - s2.connect - - assert s2.cipher[0].start_with?("AECDH"), "AECDH should be used" - if s2.respond_to?(:tmp_key) - assert_equal "secp384r1", s2.tmp_key.group.curve_name - end - ensure - th.join if th - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - end - def test_connect_accept_nonblock_no_exception ctx2 = OpenSSL::SSL::SSLContext.new ctx2.ciphers = "ADH" diff --git a/test/openssl/test_pkcs12.rb b/test/openssl/test_pkcs12.rb index 7ab501c480..4f2544dfee 100644 --- a/test/openssl/test_pkcs12.rb +++ b/test/openssl/test_pkcs12.rb @@ -180,6 +180,112 @@ Li8JsX5yIiuVYaBg/6ha3tOg4TCa5K/3r3tVliRZ2Es= end end + def test_new_with_one_key_and_one_cert + # generated with: + # openssl version #=> OpenSSL 1.0.2h 3 May 2016 + # openssl pkcs12 -in <@mycert> -inkey -export -out + str = <<~EOF.unpack("m").first +MIIGQQIBAzCCBgcGCSqGSIb3DQEHAaCCBfgEggX0MIIF8DCCAu8GCSqGSIb3DQEH +BqCCAuAwggLcAgEAMIIC1QYJKoZIhvcNAQcBMBwGCiqGSIb3DQEMAQYwDgQIeZPM +Rh6KiXgCAggAgIICqL6O+LCZmBzdIg6mozPF3FpY0hVbWHvTNMiDHieW3CrAanhN +YCH2/wHqH8WpFpEWwF0qEEXAWjHsIlYB4Cfqo6b7XpuZe5eVESsjNTOTMF1JCUJj +A6iNefXmCFLync1JK5LUodRDhTlKLU1WPK20X9X4vuEwHn8wt5RUb8P0E+Xh6rpS +XC4LkZKT45zF3cJa/n5+dW65ohVGNVnF9D1bCNEKHMOllK1V9omutQ9slW88hpga +LGiFsJoFOb/ESGb78KO+bd6zbX1MdKdBV+WD6t1uF/cgU65y+2A4nXs1urda+MJ7 +7iVqiB7Vnc9cANTbAkTSGNyoUDVM/NZde782/8IvddLAzUZ2EftoRDke6PvuBOVL +ljBhNWmdamrtBqzuzVZCRdWq44KZkF2Xoc9asepwIkdVmntzQF7f1Z+Ta5yg6HFp +xnr7CuM+MlHEShXkMgYtHnwAq10fDMSXIvjhi/AA5XUAusDO3D+hbtcRDcJ4uUes +dm5dhQE2qJ02Ysn4aH3o1F3RYNOzrxejHJwl0D2TCE8Ww2X342xib57+z9u03ufj +jswhiMKxy67f1LhUMq3XrT3uV6kCVXk/KUOUPcXPlPVNA5JmZeFhMp6GrtB5xJJ9 +wwBZD8UL5A2U2Mxi2OZsdUBv8eo3jnjZ284aFpt+mCjIHrLW5O0jwY8OCwSlYUoY +IY00wlabX0s82kBcIQNZbC1RSV2267ro/7A0MClc8YQ/zWN0FKY6apgtUkHJI1cL +1dc77mhnjETjwW94iLMDFy4zQfVu7IfCBqOBzygRNnqqUG66UhTs1xFnWM0mWXl/ +Zh9+AMpbRLIPaKCktIjl5juzzm+KEgkhD+707XRCFIGUYGP5bSHzGaz8PK9hj0u1 +E2SpZHUvYOcawmxtA7pmpSxl5uQjMIIC+QYJKoZIhvcNAQcBoIIC6gSCAuYwggLi +MIIC3gYLKoZIhvcNAQwKAQKgggKmMIICojAcBgoqhkiG9w0BDAEDMA4ECKB338m8 +qSzHAgIIAASCAoACFhJeqA3xx+s1qIH6udNQYY5hAL6oz7SXoGwFhDiceSyJjmAD +Dby9XWM0bPl1Gj5nqdsuI/lAM++fJeoETk+rxw8q6Ofk2zUaRRE39qgpwBwSk44o +0SAFJ6bzHpc5CFh6sZmDaUX5Lm9GtjnGFmmsPTSJT5an5JuJ9WczGBEd0nSBQhJq +xHbTGZiN8i3SXcIH531Sub+CBIFWy5lyCKgDYh/kgJFGQAaWUOjLI+7dCEESonXn +F3Jh2uPbnDF9MGJyAFoNgWFhgSpi1cf6AUi87GY4Oyur88ddJ1o0D0Kz2uw8/bpG +s3O4PYnIW5naZ8mozzbnYByEFk7PoTwM7VhoFBfYNtBoAI8+hBnPY/Y71YUojEXf +SeX6QbtkIANfzS1XuFNKElShC3DPQIHpKzaatEsfxHfP+8VOav6zcn4mioao7NHA +x7Dp6R1enFGoQOq4UNjBT8YjnkG5vW8zQHW2dAHLTJBq6x2Fzm/4Pjo/8vM1FiGl +BQdW5vfDeJ/l6NgQm3xR9ka2E2HaDqIcj1zWbN8jy/bHPFJYuF/HH8MBV/ngMIXE +vFEW/ToYv8eif0+EpUtzBsCKD4a7qYYYh87RmEVoQU96q6m+UbhpD2WztYfAPkfo +OSL9j2QHhVczhL7OAgqNeM95pOsjA9YMe7exTeqK31LYnTX8oH8WJD1xGbRSJYgu +SY6PQbumcJkc/TFPn0GeVUpiDdf83SeG50lo/i7UKQi2l1hi5Y51fQhnBnyMr68D +llSZEvSWqfDxBJkBpeg6PIYvkTpEwKRJpVQoM3uYvdqVSSnW6rydqIb+snfOrlhd +f+xCtq9xr+kHeTSqLIDRRAnMfgFRhY3IBlj6MSUwIwYJKoZIhvcNAQkVMRYEFBdb +8XGWehZ6oPj56Pf/uId46M9AMDEwITAJBgUrDgMCGgUABBRvSCB04/f8f13pp2PF +vyl2WuMdEwQIMWFFphPkIUICAggA + EOF + p12 = OpenSSL::PKCS12.new(str, "abc123") + + assert_equal TEST_KEY_RSA1024.to_der, p12.key.to_der + assert_equal @mycert.subject.to_der, p12.certificate.subject.to_der + assert_equal [], Array(p12.ca_certs) + end + + def test_new_with_no_keys + # generated with: + # openssl pkcs12 -in <@mycert> -nokeys -export -out + str = <<~EOF.unpack("m").first +MIIDHAIBAzCCAuIGCSqGSIb3DQEHAaCCAtMEggLPMIICyzCCAscGCSqGSIb3DQEH +BqCCArgwggK0AgEAMIICrQYJKoZIhvcNAQcBMBwGCiqGSIb3DQEMAQYwDgQIX4+W +irqwH40CAggAgIICgOaCyo+5+6IOVoGCCL80c50bkkzAwqdXxvkKExJSdcJz2uMU +0gRrKnZEjL5wrUsN8RwZu8DvgQTEhNEkKsUgM7AWainmN/EnwohIdHZAHpm6WD67 +I9kLGp0/DHrqZrV9P2dLfhXLUSQE8PI0tqZPZ8UEABhizkViw4eISTkrOUN7pGbN +Qtx/oqgitXDuX2polbxYYDwt9vfHZhykHoKgew26SeJyZfeMs/WZ6olEI4cQUAFr +mvYGuC1AxEGTo9ERmU8Pm16j9Hr9PFk50WYe+rnk9oX3wJogQ7XUWS5kYf7XRycd +NDkNiwV/ts94bbuaGZp1YA6I48FXpIc8b5fX7t9tY0umGaWy0bARe1L7o0Y89EPe +lMg25rOM7j3uPtFG8whbSfdETSy57UxzzTcJ6UwexeaK6wb2jqEmj5AOoPLWeaX0 +LyOAszR3v7OPAcjIDYZGdrbb3MZ2f2vo2pdQfu9698BrWhXuM7Odh73RLhJVreNI +aezNOAtPyBlvGiBQBGTzRIYHSLL5Y5aVj2vWLAa7hjm5qTL5C5mFdDIo6TkEMr6I +OsexNQofEGs19kr8nARXDlcbEimk2VsPj4efQC2CEXZNzURsKca82pa62MJ8WosB +DTFd8X06zZZ4nED50vLopZvyW4fyW60lELwOyThAdG8UchoAaz2baqP0K4de44yM +Y5/yPFDu4+GoimipJfbiYviRwbzkBxYW8+958ILh0RtagLbvIGxbpaym9PqGjOzx +ShNXjLK2aAFZsEizQ8kd09quJHU/ogq2cUXdqqhmOqPnUWrJVi/VCoRB3Pv1/lE4 +mrUgr2YZ11rYvBw6g5XvNvFcSc53OKyV7SLn0dwwMTAhMAkGBSsOAwIaBQAEFEWP +1WRQykaoD4uJCpTx/wv0SLLBBAiDKI26LJK7xgICCAA= + EOF + p12 = OpenSSL::PKCS12.new(str, "abc123") + + assert_equal nil, p12.key + assert_equal nil, p12.certificate + assert_equal 1, p12.ca_certs.size + assert_equal @mycert.subject.to_der, p12.ca_certs[0].subject.to_der + end + + def test_new_with_no_certs + # generated with: + # openssl pkcs12 -inkey -nocerts -export -out + str = <<~EOF.unpack("m").first +MIIDJwIBAzCCAu0GCSqGSIb3DQEHAaCCAt4EggLaMIIC1jCCAtIGCSqGSIb3DQEH +AaCCAsMEggK/MIICuzCCArcGCyqGSIb3DQEMCgECoIICpjCCAqIwHAYKKoZIhvcN +AQwBAzAOBAg6AaYnJs84SwICCAAEggKAQzZH+fWSpcQYD1J7PsGSune85A++fLCQ +V7tacp2iv95GJkxwYmfTP176pJdgs00mceB9UJ/u9EX5nD0djdjjQjwo6sgKjY0q +cpVhZw8CMxw7kBD2dhtui0zT8z5hy03LePxsjEKsGiSbeVeeGbSfw/I6AAYbv+Uh +O/YPBGumeHj/D2WKnfsHJLQ9GAV3H6dv5VKYNxjciK7f/JEyZCuUQGIN64QFHDhJ +7fzLqd/ul3FZzJZO6a+dwvcgux09SKVXDRSeFmRCEX4b486iWhJJVspCo9P2KNne +ORrpybr3ZSwxyoICmjyo8gj0OSnEfdx9790Ej1takPqSA1wIdSdBLekbZqB0RBQg +DEuPOsXNo3QFi8ji1vu0WBRJZZSNC2hr5NL6lNR+DKxG8yzDll2j4W4BBIp22mAE +7QRX7kVxu17QJXQhOUac4Dd1qXmzebP8t6xkAxD9L7BWEN5OdiXWwSWGjVjMBneX +nYObi/3UT/aVc5WHMHK2BhCI1bwH51E6yZh06d5m0TQpYGUTWDJdWGBSrp3A+8jN +N2PMQkWBFrXP3smHoTEN4oZC4FWiPsIEyAkQsfKRhcV9lGKl2Xgq54ROTFLnwKoj +Z3zJScnq9qmNzvVZSMmDLkjLyDq0pxRxGKBvgouKkWY7VFFIwwBIJM39iDJ5NbBY +i1AQFTRsRSsZrNVPasCXrIq7bhMoJZb/YZOGBLNyJVqKUoYXhtwsajzSq54VlWft +JxsPayEd4Vi6O9EU1ahnj6qFEZiKFzsicgK2J1Rb8cYagrp0XWjHW0SBn5GVUWCg +GUokSFG/0JTdeYTo/sQuG4qNgJkOolRjpeI48Fciq5VUWLvVdKioXzAxMCEwCQYF +Kw4DAhoFAAQUYAuwVtGD1TdgbFK4Yal2XBgwUR4ECEawsN3rNaa6AgIIAA== + EOF + p12 = OpenSSL::PKCS12.new(str, "abc123") + + assert_equal TEST_KEY_RSA1024.to_der, p12.key.to_der + assert_equal nil, p12.certificate + assert_equal [], Array(p12.ca_certs) + end + def test_dup p12 = OpenSSL::PKCS12.create("pass", "name", TEST_KEY_RSA1024, @mycert) assert_equal p12.to_der, p12.dup.to_der diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index 1b9548a3e9..0af93a8bc2 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -5,182 +5,96 @@ if defined?(OpenSSL::TestUtils) class OpenSSL::TestSSL < OpenSSL::SSLTestCase - def test_ctx_setup + def test_ctx_options ctx = OpenSSL::SSL::SSLContext.new - assert_equal(ctx.setup, true) - assert_equal(ctx.setup, nil) - end - def test_ctx_setup_invalid - m = OpenSSL::SSL::SSLContext::METHODS.first - assert_raise_with_message(ArgumentError, /null/) { - OpenSSL::SSL::SSLContext.new("#{m}\0") - } - assert_raise_with_message(ArgumentError, /\u{ff33 ff33 ff2c}/) { - OpenSSL::SSL::SSLContext.new("\u{ff33 ff33 ff2c}") - } - end - - def test_options_defaults_to_OP_ALL_on - ctx = OpenSSL::SSL::SSLContext.new - assert_equal(OpenSSL::SSL::OP_ALL, (OpenSSL::SSL::OP_ALL & ctx.options)) - end - - def test_setting_twice - ctx = OpenSSL::SSL::SSLContext.new + assert (OpenSSL::SSL::OP_ALL & ctx.options) == OpenSSL::SSL::OP_ALL, + "OP_ALL is set by default" ctx.options = 4 - assert_equal 4, (ctx.options & OpenSSL::SSL::OP_ALL) - ctx.options = OpenSSL::SSL::OP_ALL - assert_equal OpenSSL::SSL::OP_ALL, (ctx.options & OpenSSL::SSL::OP_ALL) - end - - def test_options_setting_nil_means_all - ctx = OpenSSL::SSL::SSLContext.new + assert_equal 4, ctx.options ctx.options = nil - assert_equal(OpenSSL::SSL::OP_ALL, (OpenSSL::SSL::OP_ALL & ctx.options)) + assert_equal OpenSSL::SSL::OP_ALL, ctx.options + + assert_equal true, ctx.setup + assert_predicate ctx, :frozen? + assert_equal nil, ctx.setup end - def test_setting_options_raises_after_setup - ctx = OpenSSL::SSL::SSLContext.new - options = ctx.options - ctx.setup - assert_raise(RuntimeError) do - ctx.options = options - end + def test_ssl_with_server_cert + ctx_proc = -> ctx { + ctx.cert = @svr_cert + ctx.key = @svr_key + ctx.extra_chain_cert = [@ca_cert] + } + server_proc = -> (ctx, ssl) { + assert_equal @svr_cert.to_der, ssl.cert.to_der + assert_equal nil, ssl.peer_cert + + readwrite_loop(ctx, ssl) + } + start_server(ctx_proc: ctx_proc, server_proc: server_proc) { |server, port| + begin + sock = TCPSocket.new("127.0.0.1", port) + ctx = OpenSSL::SSL::SSLContext.new + ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) + ssl.connect + + assert_equal sock, ssl.io + assert_equal nil, ssl.cert + assert_equal @svr_cert.to_der, ssl.peer_cert.to_der + assert_equal 2, ssl.peer_cert_chain.size + assert_equal @svr_cert.to_der, ssl.peer_cert_chain[0].to_der + assert_equal @ca_cert.to_der, ssl.peer_cert_chain[1].to_der + ensure + ssl&.close + sock&.close + end + } end - def test_ctx_setup_no_compression - ctx = OpenSSL::SSL::SSLContext.new - ctx.options = OpenSSL::SSL::OP_ALL | OpenSSL::SSL::OP_NO_COMPRESSION - assert_equal(ctx.setup, true) - assert_equal(ctx.setup, nil) - assert_equal(OpenSSL::SSL::OP_NO_COMPRESSION, - ctx.options & OpenSSL::SSL::OP_NO_COMPRESSION) - end if defined?(OpenSSL::SSL::OP_NO_COMPRESSION) - - def test_ctx_setup_with_extra_chain_cert - ctx = OpenSSL::SSL::SSLContext.new - ctx.extra_chain_cert = [@ca_cert, @cli_cert] - assert_equal(ctx.setup, true) - assert_equal(ctx.setup, nil) - end - - def test_not_started_session - pend "non socket argument of SSLSocket.new is not supported on this platform" if /mswin|mingw/ =~ RUBY_PLATFORM - open(__FILE__) do |f| - assert_nil EnvUtil.suppress_warning { OpenSSL::SSL::SSLSocket.new(f).cert } - end - end - - def test_ssl_gets - start_server(OpenSSL::SSL::VERIFY_NONE, true) { |server, port| + def test_sysread_and_syswrite + start_server { |server, port| server_connect(port) { |ssl| - ssl.write "abc\n" - IO.select [ssl] + str = "x" * 100 + "\n" + ssl.syswrite(str) + newstr = ssl.sysread(str.bytesize) + assert_equal(str, newstr) - line = ssl.gets - - assert_equal "abc\n", line - assert_equal Encoding::BINARY, line.encoding + buf = "" + ssl.syswrite(str) + assert_same buf, ssl.sysread(str.size, buf) + assert_equal(str, newstr) } } end - def test_ssl_read_nonblock - start_server(OpenSSL::SSL::VERIFY_NONE, true) { |server, port| - server_connect(port) { |ssl| - assert_raise(IO::WaitReadable) { ssl.read_nonblock(100) } - ssl.write("abc\n") - IO.select [ssl] - assert_equal('a', ssl.read_nonblock(1)) - assert_equal("bc\n", ssl.read_nonblock(100)) - assert_raise(IO::WaitReadable) { ssl.read_nonblock(100) } - } - } - end + def test_sync_close + start_server { |server, port| + begin + sock = TCPSocket.new("127.0.0.1", port) + ssl = OpenSSL::SSL::SSLSocket.new(sock) + ssl.connect + ssl.close + assert_not_predicate sock, :closed? + ensure + sock&.close + end - def test_ssl_sysread_blocking_error - start_server(OpenSSL::SSL::VERIFY_NONE, true) { |server, port| - server_connect(port) { |ssl| - ssl.write("abc\n") - assert_raise(TypeError) { ssl.sysread(4, exception: false) } - buf = '' - assert_raise(ArgumentError) { ssl.sysread(4, buf, exception: false) } - assert_equal '', buf - assert_equal buf.object_id, ssl.sysread(4, buf).object_id - assert_equal "abc\n", buf - } - } - end - - def test_connect_and_close - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| - sock = TCPSocket.new("127.0.0.1", port) - ssl = OpenSSL::SSL::SSLSocket.new(sock) - assert(ssl.connect) - ssl.close - assert(!sock.closed?) - sock.close - - sock = TCPSocket.new("127.0.0.1", port) - ssl = OpenSSL::SSL::SSLSocket.new(sock) - ssl.sync_close = true # !! - assert(ssl.connect) - ssl.close - assert(sock.closed?) - } - end - - def test_read_and_write - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| - server_connect(port) { |ssl| - # syswrite and sysread - ITERATIONS.times{|i| - str = "x" * 100 + "\n" - ssl.syswrite(str) - newstr = '' - newstr << ssl.sysread(str.size - newstr.size) until newstr.size == str.size - assert_equal(str, newstr) - - str = "x" * i * 100 + "\n" - buf = "" - ssl.syswrite(str) - assert_equal(buf.object_id, ssl.sysread(str.size, buf).object_id) - newstr = buf - newstr << ssl.sysread(str.size - newstr.size) until newstr.size == str.size - assert_equal(str, newstr) - } - - # puts and gets - ITERATIONS.times{ - str = "x" * 100 + "\n" - ssl.puts(str) - assert_equal(str, ssl.gets) - - str = "x" * 100 - ssl.puts(str) - assert_equal(str, ssl.gets("\n", 100)) - assert_equal("\n", ssl.gets) - } - - # read and write - ITERATIONS.times{|i| - str = "x" * 100 + "\n" - ssl.write(str) - assert_equal(str, ssl.read(str.size)) - - str = "x" * i * 100 + "\n" - buf = "" - ssl.write(str) - assert_equal(buf.object_id, ssl.read(str.size, buf).object_id) - assert_equal(str, buf) - } - } + begin + sock = TCPSocket.new("127.0.0.1", port) + ssl = OpenSSL::SSL::SSLSocket.new(sock) + ssl.sync_close = true # !! + ssl.connect + ssl.close + assert_predicate sock, :closed? + ensure + sock&.close + end } end def test_copy_stream - start_server(OpenSSL::SSL::VERIFY_NONE, true) do |server, port| + start_server do |server, port| server_connect(port) do |ssl| IO.pipe do |r, w| str = "hello world\n" @@ -195,7 +109,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase def test_client_auth_failure vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT - start_server(vflag, true, :ignore_listener_error => true){|server, port| + start_server(verify_mode: vflag, ignore_listener_error: true) { |server, port| sock = TCPSocket.new("127.0.0.1", port) ssl = OpenSSL::SSL::SSLSocket.new(sock) ssl.sync_close = true @@ -209,7 +123,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase def test_client_auth_success vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT - start_server(vflag, true){|server, port| + start_server(verify_mode: vflag) { |server, port| ctx = OpenSSL::SSL::SSLContext.new ctx.key = @cli_key ctx.cert = @cli_cert @@ -236,7 +150,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase def test_client_auth_public_key vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT - start_server(vflag, true, ignore_listener_error: true) do |server, port| + start_server(verify_mode: vflag, ignore_listener_error: true) do |server, port| assert_raise(ArgumentError) { ctx = OpenSSL::SSL::SSLContext.new ctx.key = @cli_key.public_key @@ -258,7 +172,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end vflag = OpenSSL::SSL::VERIFY_PEER|OpenSSL::SSL::VERIFY_FAIL_IF_NO_PEER_CERT - start_server(vflag, true, :ctx_proc => ctx_proc){|server, port| + start_server(verify_mode: vflag, ctx_proc: ctx_proc) { |server, port| ctx = OpenSSL::SSL::SSLContext.new client_ca_from_server = nil ctx.client_cert_cb = Proc.new do |sslconn| @@ -271,7 +185,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase def test_read_nonblock_without_session OpenSSL::TestUtils.silent do - start_server(OpenSSL::SSL::VERIFY_NONE, false){|server, port| + start_server(start_immediately: false) { |server, port| sock = TCPSocket.new("127.0.0.1", port) ssl = OpenSSL::SSL::SSLSocket.new(sock) ssl.sync_close = true @@ -288,32 +202,50 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end def test_starttls - OpenSSL::TestUtils.silent do - start_server(OpenSSL::SSL::VERIFY_NONE, false){|server, port| - sock = TCPSocket.new("127.0.0.1", port) - ssl = OpenSSL::SSL::SSLSocket.new(sock) - ssl.sync_close = true - str = "x" * 1000 + "\n" + server_proc = -> (ctx, ssl) { + begin + while line = ssl.gets + if line =~ /^STARTTLS$/ + ssl.write("x") + ssl.flush + ssl.accept + next + end + ssl.write(line) + end + rescue OpenSSL::SSL::SSLError + rescue IOError + ensure + ssl.close rescue nil + end + } - ITERATIONS.times{ - ssl.puts(str) - assert_equal(str, ssl.gets) - } - starttls(ssl) + EnvUtil.suppress_warning do # read/write on not started session + start_server(start_immediately: false, + server_proc: server_proc) { |server, port| + begin + sock = TCPSocket.new("127.0.0.1", port) + ssl = OpenSSL::SSL::SSLSocket.new(sock) - ITERATIONS.times{ - ssl.puts(str) - assert_equal(str, ssl.gets) - } + ssl.puts "plaintext" + assert_equal "plaintext\n", ssl.gets - ssl.close + ssl.puts("STARTTLS") + ssl.read(1) + ssl.connect + + ssl.puts "over-tls" + assert_equal "over-tls\n", ssl.gets + ensure + ssl&.close + sock&.close + end } end end def test_parallel - GC.start - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| ssls = [] 10.times{ sock = TCPSocket.new("127.0.0.1", port) @@ -334,7 +266,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end def test_verify_result - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ignore_listener_error => true){|server, port| + start_server(ignore_listener_error: true) { |server, port| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER @@ -348,7 +280,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end } - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER @@ -366,7 +298,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end } - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ignore_listener_error => true){|server, port| + start_server(ignore_listener_error: true) { |server, port| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER @@ -386,7 +318,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end def test_exception_in_verify_callback_is_ignored - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ignore_listener_error => true){|server, port| + start_server(ignore_listener_error: true) { |server, port| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER @@ -411,34 +343,40 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase def test_sslctx_set_params ctx = OpenSSL::SSL::SSLContext.new ctx.set_params - assert_equal(OpenSSL::SSL::VERIFY_PEER, ctx.verify_mode) - ciphers = ctx.ciphers - ciphers_versions = ciphers.collect{|_, v, _, _| v } - ciphers_names = ciphers.collect{|v, _, _, _| v } - assert(ciphers_names.all?{|v| /A(EC)?DH/ !~ v }) - assert(ciphers_names.all?{|v| /(RC4|MD5|EXP)/ !~ v }) - assert(ciphers_versions.all?{|v| /SSLv2/ !~ v }) + + assert_equal OpenSSL::SSL::VERIFY_PEER, ctx.verify_mode + ciphers_names = ctx.ciphers.collect{|v, _, _, _| v } + assert ciphers_names.all?{|v| /A(EC)?DH/ !~ v }, "anon ciphers are disabled" + assert ciphers_names.all?{|v| /(RC4|MD5|EXP|DES)/ !~ v }, "weak ciphers are disabled" + assert_equal 0, ctx.options & OpenSSL::SSL::OP_DONT_INSERT_EMPTY_FRAGMENTS + if defined?(OpenSSL::SSL::OP_NO_COMPRESSION) # >= 1.0.0 + assert_equal OpenSSL::SSL::OP_NO_COMPRESSION, + ctx.options & OpenSSL::SSL::OP_NO_COMPRESSION + end end def test_post_connect_check_with_anon_ciphers - sslerr = OpenSSL::SSL::SSLError + ctx_proc = -> ctx { + ctx.ciphers = "aNULL" + ctx.security_level = 0 + } - start_server(OpenSSL::SSL::VERIFY_NONE, true, {use_anon_cipher: true}){|server, port| + start_server(ctx_proc: ctx_proc) { |server, port| ctx = OpenSSL::SSL::SSLContext.new ctx.ciphers = "aNULL" ctx.security_level = 0 server_connect(port, ctx) { |ssl| - assert_raise_with_message(sslerr, /anonymous cipher suite/i){ + assert_raise_with_message(OpenSSL::SSL::SSLError, /anonymous cipher suite/i) { ssl.post_connection_check("localhost.localdomain") } } } - end if OpenSSL::ExtConfig::TLS_DH_anon_WITH_AES_256_GCM_SHA384 + end def test_post_connection_check sslerr = OpenSSL::SSL::SSLError - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| server_connect(port) { |ssl| assert_raise(sslerr){ssl.post_connection_check("localhost.localdomain")} assert_raise(sslerr){ssl.post_connection_check("127.0.0.1")} @@ -461,7 +399,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase ] @svr_cert = issue_cert(@svr, @svr_key, 4, now, now+1800, exts, @ca_cert, @ca_key, OpenSSL::Digest::SHA1.new) - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| server_connect(port) { |ssl| assert(ssl.post_connection_check("localhost.localdomain")) assert(ssl.post_connection_check("127.0.0.1")) @@ -483,7 +421,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase ] @svr_cert = issue_cert(@svr, @svr_key, 5, now, now+1800, exts, @ca_cert, @ca_key, OpenSSL::Digest::SHA1.new) - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| server_connect(port) { |ssl| assert(ssl.post_connection_check("localhost.localdomain")) assert_raise(sslerr){ssl.post_connection_check("127.0.0.1")} @@ -685,39 +623,52 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase end end - def test_servername_cb_sets_context_on_the_socket - hostname = 'example.org' - + def test_tlsext_hostname ctx3 = OpenSSL::SSL::SSLContext.new - ctx3.ciphers = "aNULL" + ctx3.ciphers = "ADH" ctx3.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 } ctx3.security_level = 0 + assert_not_predicate ctx3, :frozen? - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.servername_cb = lambda { |args| ctx3 } + ctx_proc = -> ctx { + ctx.ciphers = "ALL:!aNULL" + ctx.servername_cb = proc { |ssl, servername| + case servername + when "foo.example.com" + ctx3 + when "bar.example.com" + nil + else + raise "unknown hostname" + end + } + } + start_server(ctx_proc: ctx_proc) do |server, port| + ctx = OpenSSL::SSL::SSLContext.new + ctx.ciphers = "ALL" + ctx.security_level = 0 - sock1, sock2 = socketpair + sock = TCPSocket.new("127.0.0.1", port) + begin + ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) + ssl.hostname = "foo.example.com" + ssl.connect + assert_match /^ADH-/, ssl.cipher[0], "the context returned by servername_cb is used" + assert_predicate ctx3, :frozen? + ensure + sock.close + end - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "aNULL" - ctx1.security_level = 0 - - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - s1.hostname = hostname - t = Thread.new { s1.connect } - - assert_equal ctx2, s2.context - accepted = s2.accept - assert_equal ctx3, s2.context - assert t.value - ensure - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - accepted.close if accepted.respond_to?(:close) + sock = TCPSocket.new("127.0.0.1", port) + begin + ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) + ssl.hostname = "bar.example.com" + ssl.connect + assert_not_match /^A(EC)?DH-/, ssl.cipher[0], "the original context is used" + ensure + sock.close + end + end end def test_servername_cb_raises_an_exception_on_unknown_objects @@ -755,148 +706,6 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase sock2.close if sock2 end - def test_servername_cb_calls_setup_on_returned_ctx - hostname = 'example.org' - - ctx3 = OpenSSL::SSL::SSLContext.new - ctx3.ciphers = "aNULL" - ctx3.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 } - ctx3.security_level = 0 - assert_not_predicate ctx3, :frozen? - - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.servername_cb = lambda { |args| ctx3 } - - sock1, sock2 = socketpair - - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "aNULL" - ctx1.security_level = 0 - - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - s1.hostname = hostname - t = Thread.new { s1.connect } - - accepted = s2.accept - assert t.value - assert_predicate ctx3, :frozen? - ensure - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - accepted.close if accepted.respond_to?(:close) - end - - def test_servername_cb_can_return_nil - hostname = 'example.org' - - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "aNULL" - ctx2.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 } - ctx2.security_level = 0 - ctx2.servername_cb = lambda { |args| nil } - - sock1, sock2 = socketpair - - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "aNULL" - ctx1.security_level = 0 - - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - s1.hostname = hostname - t = Thread.new { s1.connect } - - accepted = s2.accept - assert t.value - ensure - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - accepted.close if accepted.respond_to?(:close) - end - - def test_servername_cb - lambda_called = nil - cb_socket = nil - hostname = 'example.org' - - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.ciphers = "aNULL" - ctx2.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 } - ctx2.security_level = 0 - ctx2.servername_cb = lambda do |args| - cb_socket = args[0] - lambda_called = args[1] - ctx2 - end - - sock1, sock2 = socketpair - - s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) - - ctx1 = OpenSSL::SSL::SSLContext.new - ctx1.ciphers = "aNULL" - ctx1.security_level = 0 - - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - s1.hostname = hostname - t = Thread.new { s1.connect } - - accepted = s2.accept - assert t.value - assert_equal hostname, lambda_called - assert_equal s2, cb_socket - ensure - s1.close if s1 - s2.close if s2 - sock1.close if sock1 - sock2.close if sock2 - accepted.close if accepted.respond_to?(:close) - end - - def test_tlsext_hostname - return unless OpenSSL::SSL::SSLSocket.instance_methods.include?(:hostname) - - ctx_proc = Proc.new do |ctx, ssl| - foo_ctx = OpenSSL::SSL::SSLContext.new - - ctx.servername_cb = Proc.new do |ssl2, hostname| - case hostname - when 'foo.example.com' - foo_ctx - when 'bar.example.com' - nil - else - raise "unknown hostname #{hostname.inspect}" - end - end - end - - server_proc = Proc.new do |ctx, ssl| - readwrite_loop(ctx, ssl) - end - - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc, :server_proc => server_proc) do |server, port| - 2.times do |i| - ctx = OpenSSL::SSL::SSLContext.new - # disable RFC4507 support - ctx.options = OpenSSL::SSL::OP_NO_TICKET - server_connect(port, ctx) { |ssl| - ssl.hostname = (i & 1 == 0) ? 'foo.example.com' : 'bar.example.com' - str = "x" * 100 + "\n" - ssl.puts(str) - assert_equal(str, ssl.gets) - } - end - end - end - def test_verify_hostname_on_connect ctx_proc = proc { |ctx| now = Time.now @@ -910,8 +719,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase ctx.key = @svr_key } - start_server(OpenSSL::SSL::VERIFY_NONE, true, ctx_proc: ctx_proc, - ignore_listener_error: true) do |svr, port| + start_server(ctx_proc: ctx_proc, ignore_listener_error: true) do |server, port| ctx = OpenSSL::SSL::SSLContext.new ctx.verify_hostname = true ctx.cert_store = OpenSSL::X509::Store.new @@ -960,7 +768,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase assert_equal(num_written, raw_size) ssl.close } - start_server(OpenSSL::SSL::VERIFY_NONE, true, :server_proc => server_proc){|server, port| + start_server(server_proc: server_proc) { |server, port| server_connect(port) { |ssl| str = auml * i num_written = ssl.write(str) @@ -976,7 +784,7 @@ class OpenSSL::TestSSL < OpenSSL::SSLTestCase # But it also degrades gracefully, so keep it ctx.options = OpenSSL::SSL::OP_ALL } - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc){|server, port| + start_server(ctx_proc: ctx_proc) { |server, port| server_connect(port) { |ssl| ssl.puts('hello') assert_equal("hello\n", ssl.gets) @@ -1123,9 +931,10 @@ if OpenSSL::OPENSSL_VERSION_NUMBER >= 0x10002000 ssl2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) t = Thread.new { - assert_handshake_error { ssl2.connect } + ssl2.connect_nonblock(exception: false) } - assert_raise(TypeError) { ssl1.accept } + assert_raise_with_message(TypeError, /nil/) { ssl1.accept } + t.join ensure sock1&.close sock2&.close @@ -1207,7 +1016,7 @@ end def test_invalid_shutdown_by_gc assert_nothing_raised { - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| 10.times { sock = TCPSocket.new("127.0.0.1", port) ssl = OpenSSL::SSL::SSLSocket.new(sock) @@ -1220,7 +1029,7 @@ end end def test_close_after_socket_close - start_server(OpenSSL::SSL::VERIFY_NONE, true){|server, port| + start_server { |server, port| sock = TCPSocket.new("127.0.0.1", port) ssl = OpenSSL::SSL::SSLSocket.new(sock) ssl.sync_close = true @@ -1274,7 +1083,7 @@ end 'AES128-SHA' => nil } conf_proc = Proc.new { |ctx| ctx.ciphers = 'ALL' } - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => conf_proc) do |server, port| + start_server(ctx_proc: conf_proc) do |server, port| ciphers.each do |cipher, ephemeral| ctx = OpenSSL::SSL::SSLContext.new begin @@ -1294,6 +1103,123 @@ end end end + def test_dh_callback + called = false + ctx_proc = -> ctx { + ctx.ciphers = "DH:!NULL" + ctx.tmp_dh_callback = ->(*args) { + called = true + OpenSSL::TestUtils::TEST_KEY_DH1024 + } + } + start_server(ctx_proc: ctx_proc) do |server, port| + server_connect(port) { |ssl| + assert called, "dh callback should be called" + if ssl.respond_to?(:tmp_key) + assert_equal OpenSSL::TestUtils::TEST_KEY_DH1024.to_der, ssl.tmp_key.to_der + end + } + end + end + + def test_connect_works_when_setting_dh_callback_to_nil + ctx_proc = -> ctx { + ctx.ciphers = "DH:!NULL" # use DH + ctx.tmp_dh_callback = nil + } + start_server(ctx_proc: ctx_proc) do |server, port| + EnvUtil.suppress_warning { # uses default callback + assert_nothing_raised { + server_connect(port) { } + } + } + end + end + + def test_ecdh_callback + return unless OpenSSL::SSL::SSLContext.instance_methods.include?(:tmp_ecdh_callback) + EnvUtil.suppress_warning do # tmp_ecdh_callback is deprecated (2016-05) + begin + called = false + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.ciphers = "ECDH" + # OpenSSL 1.1.0 doesn't have tmp_ecdh_callback so this shouldn't be required + ctx2.security_level = 0 + ctx2.tmp_ecdh_callback = ->(*args) { + called = true + OpenSSL::PKey::EC.new "prime256v1" + } + + sock1, sock2 = socketpair + + s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) + ctx1 = OpenSSL::SSL::SSLContext.new + ctx1.ciphers = "ECDH" + ctx1.security_level = 0 + + s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) + th = Thread.new do + s1.connect + end + + s2.accept + assert called, 'ecdh callback should be called' + rescue OpenSSL::SSL::SSLError => e + if e.message =~ /no cipher match/ + pend "ECDH cipher not supported." + else + raise e + end + ensure + th.join if th + s1.close if s1 + s2.close if s2 + sock1.close if sock1 + sock2.close if sock2 + end + end + end + + def test_ecdh_curves + ctx_proc = -> ctx { + begin + ctx.ciphers = "ECDH:!NULL" + rescue OpenSSL::SSL::SSLError + pend "ECDH is not enabled in this OpenSSL" if $!.message =~ /no cipher match/ + raise + end + ctx.ecdh_curves = "P-384:P-521" + } + start_server(ctx_proc: ctx_proc, ignore_listener_error: true) do |server, port| + ctx = OpenSSL::SSL::SSLContext.new + ctx.ecdh_curves = "P-256:P-384" # disable P-521 for OpenSSL >= 1.0.2 + + server_connect(port, ctx) { |ssl| + assert ssl.cipher[0].start_with?("ECDH"), "ECDH should be used" + if ssl.respond_to?(:tmp_key) + assert_equal "secp384r1", ssl.tmp_key.group.curve_name + end + } + + if OpenSSL::OPENSSL_VERSION_NUMBER >= 0x10002000 && + !OpenSSL::OPENSSL_VERSION.include?("LibreSSL") + ctx = OpenSSL::SSL::SSLContext.new + ctx.ecdh_curves = "P-256" + + assert_raise(OpenSSL::SSL::SSLError) { + server_connect(port, ctx) { } + } + + ctx = OpenSSL::SSL::SSLContext.new + ctx.ecdh_curves = "P-521:P-384" + + server_connect(port, ctx) { |ssl| + assert_equal "secp521r1", ssl.tmp_key.group.curve_name + } + end + end + end + def test_security_level ctx = OpenSSL::SSL::SSLContext.new begin @@ -1325,17 +1251,16 @@ end private - def start_server_version(version, ctx_proc=nil, server_proc=nil, &blk) + def start_server_version(version, ctx_proc = nil, + server_proc = method(:readwrite_loop), &blk) ctx_wrap = Proc.new { |ctx| ctx.ssl_version = version ctx_proc.call(ctx) if ctx_proc } start_server( - OpenSSL::SSL::VERIFY_NONE, - true, - :ctx_proc => ctx_wrap, - :server_proc => server_proc, - :ignore_listener_error => true, + ctx_proc: ctx_wrap, + server_proc: server_proc, + ignore_listener_error: true, &blk ) end diff --git a/test/openssl/test_ssl_session.rb b/test/openssl/test_ssl_session.rb index 9bcec10165..b2643edd8c 100644 --- a/test/openssl/test_ssl_session.rb +++ b/test/openssl/test_ssl_session.rb @@ -27,7 +27,7 @@ tddwpBAEDjcwMzA5NTYzMTU1MzAwpQMCARM= -----END SSL SESSION PARAMETERS----- SESSION - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ignore_listener_error => true) { |_, port| + start_server(ignore_listener_error: true) { |_, port| ctx = OpenSSL::SSL::SSLContext.new ctx.session_cache_mode = OpenSSL::SSL::SSLContext::SESSION_CACHE_CLIENT ctx.session_id_context = self.object_id.to_s @@ -46,7 +46,7 @@ tddwpBAEDjcwMzA5NTYzMTU1MzAwpQMCARM= def test_session Timeout.timeout(5) do - start_server(OpenSSL::SSL::VERIFY_NONE, true) do |server, port| + start_server do |server, port| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new("TLSv1") ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) @@ -154,7 +154,7 @@ __EOS__ def test_client_session last_session = nil - start_server(OpenSSL::SSL::VERIFY_NONE, true) do |server, port| + start_server do |server, port| 2.times do sock = TCPSocket.new("127.0.0.1", port) # Debian's openssl 0.9.8g-13 failed at assert(ssl.session_reused?), @@ -237,7 +237,7 @@ __EOS__ end first_session = nil - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc, :server_proc => server_proc) do |server, port| + start_server(ctx_proc: ctx_proc, server_proc: server_proc) do |server, port| 10.times do |i| sock = TCPSocket.new("127.0.0.1", port) ctx = OpenSSL::SSL::SSLContext.new @@ -285,7 +285,7 @@ __EOS__ # any resulting value is OK (ignored) } - start_server(OpenSSL::SSL::VERIFY_NONE, true) do |server, port| + start_server do |server, port| sock = TCPSocket.new("127.0.0.1", port) begin ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx) @@ -344,7 +344,7 @@ __EOS__ c.session_cache_stats readwrite_loop(c, ssl) } - start_server(OpenSSL::SSL::VERIFY_NONE, true, :ctx_proc => ctx_proc, :server_proc => server_proc) do |server, port| + start_server(ctx_proc: ctx_proc, server_proc: server_proc) do |server, port| last_client_session = nil 3.times do sock = TCPSocket.new("127.0.0.1", port) diff --git a/test/openssl/utils.rb b/test/openssl/utils.rb index 2288a26035..6f3a3c6d1c 100644 --- a/test/openssl/utils.rb +++ b/test/openssl/utils.rb @@ -241,10 +241,6 @@ AQjjxMXhwULlmuR/K+WwlaZPiLIBYalLAZQ7ZbOPeVkJ8ePao0eLAgEC def readwrite_loop(ctx, ssl) while line = ssl.gets - if line =~ /^STARTTLS$/ - ssl.accept - next - end ssl.write(line) end rescue OpenSSL::SSL::SSLError @@ -281,22 +277,15 @@ AQjjxMXhwULlmuR/K+WwlaZPiLIBYalLAZQ7ZbOPeVkJ8ePao0eLAgEC end end - def start_server(verify_mode, start_immediately, args = {}, &block) + def start_server(verify_mode: OpenSSL::SSL::VERIFY_NONE, start_immediately: true, + ctx_proc: nil, server_proc: method(:readwrite_loop), + ignore_listener_error: false, &block) IO.pipe {|stop_pipe_r, stop_pipe_w| - ctx_proc = args[:ctx_proc] - server_proc = args[:server_proc] - ignore_listener_error = args.fetch(:ignore_listener_error, false) - use_anon_cipher = args.fetch(:use_anon_cipher, false) - server_proc ||= method(:readwrite_loop) - store = OpenSSL::X509::Store.new store.add_cert(@ca_cert) store.purpose = OpenSSL::X509::PURPOSE_SSL_CLIENT ctx = OpenSSL::SSL::SSLContext.new - ctx.ciphers = "ADH-AES256-GCM-SHA384" if use_anon_cipher - ctx.security_level = 0 if use_anon_cipher ctx.cert_store = store - #ctx.extra_chain_cert = [ ca_cert ] ctx.cert = @svr_cert ctx.key = @svr_key ctx.tmp_dh_callback = proc { OpenSSL::TestUtils::TEST_KEY_DH1024 } @@ -341,13 +330,6 @@ AQjjxMXhwULlmuR/K+WwlaZPiLIBYalLAZQ7ZbOPeVkJ8ePao0eLAgEC end } end - - def starttls(ssl) - ssl.puts("STARTTLS") - sleep 1 # When this line is eliminated, process on Cygwin blocks - # forever at ssl.connect. But I don't know why it does. - ssl.connect - end end class OpenSSL::PKeyTestCase < OpenSSL::TestCase