diff --git a/test/openssl/ssl_server.rb b/test/openssl/ssl_server.rb index 53e520379b..ce3c6132cd 100644 --- a/test/openssl/ssl_server.rb +++ b/test/openssl/ssl_server.rb @@ -49,30 +49,29 @@ Socket.do_not_reverse_lookup = true tcps = TCPServer.new("0.0.0.0", port) ssls = OpenSSL::SSL::SSLServer.new(tcps, ctx) ssls.start_immediately = start_immediately -ssock = nil Thread.start{ - while line = $stdin.gets - if /STARTTLS/ =~ line - ssock && ssock.accept - end + while true + $stdin.gets || exit end - exit } $stdout.sync = true $stdout.puts Process.pid loop do - s = ssls.accept - ssock = s + ssl = ssls.accept Thread.start{ q = Queue.new - th = Thread.start{ s.write(q.shift) while true } - while line = s.gets + th = Thread.start{ ssl.write(q.shift) while true } + while line = ssl.gets + if line =~ /^STARTTLS$/ + ssl.accept + next + end q.push(line) end - th.kill - s.close unless s.closed? + th.kill if q.empty? + ssl.close } end diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index e0517f9953..3ca25cbfe1 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -16,7 +16,7 @@ class OpenSSL::TestSSL < Test::Unit::TestCase ) SSL_SERVER = File.join(File.dirname(__FILE__), "ssl_server.rb") PORT = 20443 - ITERATIONS = ($0 == __FILE__) ? 100 : 10 + ITERATIONS = ($0 == __FILE__) ? 100 : 10 def setup @ca_key = OpenSSL::TestUtils::TEST_KEY_RSA2048 @@ -65,15 +65,23 @@ class OpenSSL::TestSSL < Test::Unit::TestCase server.write(@ca_cert.to_pem) server.write(@svr_cert.to_pem) server.write(@svr_key.to_pem) - def server.starttls; self.puts("STARTTLS") end pid = Integer(server.gets) $stderr.printf("%s started: pid=%d\n", SSL_SERVER, pid) if $DEBUG block.call(server) ensure - server.close if server + if server + server.close_write + Process.kill(:TERM, pid) rescue nil + Process.waitpid(pid) + end end end + def starttls(ssl) + ssl.puts("STARTTLS") + ssl.connect + end + def test_connect_and_close start_server(PORT, OpenSSL::SSL::VERIFY_NONE, true){ sock = TCPSocket.new("127.0.0.1", PORT) @@ -148,8 +156,7 @@ class OpenSSL::TestSSL < Test::Unit::TestCase assert_equal(str, ssl.gets) } - s.starttls - ssl.connect + starttls(ssl) ITERATIONS.times{ ssl.puts(str) @@ -177,6 +184,7 @@ class OpenSSL::TestSSL < Test::Unit::TestCase assert_equal(str, ssl.gets) } } + ssls.each{|ssl| ssl.close } } end end