mirror of
https://github.com/ruby/ruby.git
synced 2022-11-09 12:17:21 -05:00
[ruby/zlib] Synchronize access to zstream to prevent segfault in multithreaded use
I'm not sure whether this handles all multithreaded use cases, but this handles the example that crashes almost immediately and does 10,000,000 total deflates using 100 separate threads. To prevent the tests from taking forever, the committed test for this uses only 10,000 deflates across 10 separate threads, which still causes a segfault in the previous implementation almost immediately. Fixes [Bug #17803] https://github.com/ruby/zlib/commit/4b1023b3f2
This commit is contained in:
parent
218c3b2548
commit
b3d62a77d9
2 changed files with 93 additions and 1 deletions
|
@ -546,6 +546,7 @@ struct zstream {
|
|||
unsigned long flags;
|
||||
VALUE buf;
|
||||
VALUE input;
|
||||
VALUE mutex;
|
||||
z_stream stream;
|
||||
const struct zstream_funcs {
|
||||
int (*reset)(z_streamp);
|
||||
|
@ -621,6 +622,7 @@ zstream_init(struct zstream *z, const struct zstream_funcs *func)
|
|||
z->flags = 0;
|
||||
z->buf = Qnil;
|
||||
z->input = Qnil;
|
||||
z->mutex = rb_mutex_new();
|
||||
z->stream.zalloc = zlib_mem_alloc;
|
||||
z->stream.zfree = zlib_mem_free;
|
||||
z->stream.opaque = Z_NULL;
|
||||
|
@ -652,7 +654,9 @@ zstream_expand_buffer(struct zstream *z)
|
|||
rb_obj_reveal(z->buf, rb_cString);
|
||||
}
|
||||
|
||||
rb_mutex_unlock(z->mutex);
|
||||
rb_protect(rb_yield, z->buf, &state);
|
||||
rb_mutex_lock(z->mutex);
|
||||
|
||||
if (ZSTREAM_REUSE_BUFFER_P(z)) {
|
||||
rb_str_modify(z->buf);
|
||||
|
@ -1054,7 +1058,7 @@ zstream_unblock_func(void *ptr)
|
|||
}
|
||||
|
||||
static void
|
||||
zstream_run(struct zstream *z, Bytef *src, long len, int flush)
|
||||
zstream_run0(struct zstream *z, Bytef *src, long len, int flush)
|
||||
{
|
||||
struct zstream_run_args args;
|
||||
int err;
|
||||
|
@ -1138,6 +1142,32 @@ loop:
|
|||
rb_jump_tag(args.jump_state);
|
||||
}
|
||||
|
||||
struct zstream_run_synchronized_args {
|
||||
struct zstream *z;
|
||||
Bytef *src;
|
||||
long len;
|
||||
int flush;
|
||||
};
|
||||
|
||||
static VALUE
|
||||
zstream_run_synchronized(VALUE value_arg)
|
||||
{
|
||||
struct zstream_run_synchronized_args *run_args = (struct zstream_run_synchronized_args *)value_arg;
|
||||
zstream_run0(run_args->z, run_args->src, run_args->len, run_args->flush);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static void
|
||||
zstream_run(struct zstream *z, Bytef *src, long len, int flush)
|
||||
{
|
||||
struct zstream_run_synchronized_args run_args;
|
||||
run_args.z = z;
|
||||
run_args.src = src;
|
||||
run_args.len = len;
|
||||
run_args.flush = flush;
|
||||
rb_mutex_synchronize(z->mutex, zstream_run_synchronized, (VALUE)&run_args);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
zstream_sync(struct zstream *z, Bytef *src, long len)
|
||||
{
|
||||
|
@ -1183,6 +1213,7 @@ zstream_mark(void *p)
|
|||
struct zstream *z = p;
|
||||
rb_gc_mark(z->buf);
|
||||
rb_gc_mark(z->input);
|
||||
rb_gc_mark(z->mutex);
|
||||
}
|
||||
|
||||
static void
|
||||
|
|
|
@ -4,6 +4,7 @@ require 'test/unit'
|
|||
require 'stringio'
|
||||
require 'tempfile'
|
||||
require 'tmpdir'
|
||||
require 'securerandom'
|
||||
|
||||
begin
|
||||
require 'zlib'
|
||||
|
@ -503,6 +504,66 @@ if defined? Zlib
|
|||
assert_raise(Zlib::StreamError) { z.set_dictionary("foo") }
|
||||
z.close
|
||||
end
|
||||
|
||||
def test_multithread_deflate
|
||||
zd = Zlib::Deflate.new
|
||||
|
||||
s = "x" * 10000
|
||||
(0...10).map do |x|
|
||||
Thread.new do
|
||||
1000.times { zd.deflate(s) }
|
||||
end
|
||||
end.each do |th|
|
||||
th.join
|
||||
end
|
||||
ensure
|
||||
zd&.finish
|
||||
zd&.close
|
||||
end
|
||||
|
||||
def test_multithread_inflate
|
||||
zi = Zlib::Inflate.new
|
||||
|
||||
s = Zlib.deflate("x" * 10000)
|
||||
(0...10).map do |x|
|
||||
Thread.new do
|
||||
1000.times { zi.inflate(s) }
|
||||
end
|
||||
end.each do |th|
|
||||
th.join
|
||||
end
|
||||
ensure
|
||||
zi&.finish
|
||||
zi&.close
|
||||
end
|
||||
|
||||
def test_recursive_deflate
|
||||
zd = Zlib::Deflate.new
|
||||
|
||||
s = SecureRandom.random_bytes(1024**2)
|
||||
assert_raise(Zlib::BufError) do
|
||||
zd.deflate(s) do
|
||||
zd.deflate(s)
|
||||
end
|
||||
end
|
||||
ensure
|
||||
zd&.finish
|
||||
zd&.close
|
||||
end
|
||||
|
||||
def test_recursive_inflate
|
||||
zi = Zlib::Inflate.new
|
||||
|
||||
s = Zlib.deflate(SecureRandom.random_bytes(1024**2))
|
||||
|
||||
assert_raise(Zlib::DataError) do
|
||||
zi.inflate(s) do
|
||||
zi.inflate(s)
|
||||
end
|
||||
end
|
||||
ensure
|
||||
zi&.close
|
||||
end
|
||||
end
|
||||
|
||||
class TestZlibGzipFile < Test::Unit::TestCase
|
||||
|
|
Loading…
Add table
Reference in a new issue