diff --git a/ChangeLog b/ChangeLog index 0627ef6669..492aee213d 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,8 @@ +Sun Apr 20 04:45:13 2008 Tanaka Akira + + * io.c (copy_stream_body): use readpartial and write method for + non-IOs such as StringIO and ARGF. + Fri Apr 18 20:57:33 2008 Yusuke Endoh * test/ruby/test_array.rb: add tests to achieve over 95% test coverage diff --git a/io.c b/io.c index f3a7c597dc..fa8246cf08 100644 --- a/io.c +++ b/io.c @@ -124,7 +124,7 @@ VALUE rb_default_rs; static VALUE argf; -static ID id_write, id_read, id_getc, id_flush, id_encode; +static ID id_write, id_read, id_getc, id_flush, id_encode, id_readpartial; struct timeval rb_time_interval(VALUE); @@ -6250,10 +6250,11 @@ rb_io_s_read(int argc, VALUE *argv, VALUE io) struct copy_stream_struct { VALUE src; VALUE dst; + off_t copy_length; /* (off_t)-1 if not specified */ + off_t src_offset; /* (off_t)-1 if not specified */ + int src_fd; int dst_fd; - off_t copy_length; - off_t src_offset; int close_src; int close_dst; off_t total; @@ -6566,6 +6567,49 @@ finish: return Qnil; } +static VALUE +copy_stream_fallback_body(VALUE arg) +{ + struct copy_stream_struct *stp = (struct copy_stream_struct *)arg; + const int buflen = 16*1024; + VALUE n; + VALUE buf = rb_str_buf_new(buflen); + if (stp->copy_length == (off_t)-1) { + while (1) { + rb_funcall(stp->src, id_readpartial, + 2, INT2FIX(buflen), buf); + n = rb_io_write(stp->dst, buf); + stp->total += NUM2LONG(n); + } + } + else { + long rest = stp->copy_length; + while (0 < rest) { + long l = buflen < rest ? buflen : rest; + long numwrote; + rb_funcall(stp->src, id_readpartial, + 2, INT2FIX(l), buf); + n = rb_io_write(stp->dst, buf); + numwrote = NUM2LONG(n); + stp->total += numwrote; + rest -= numwrote; + } + } + return Qnil; +} + +static VALUE +copy_stream_fallback(struct copy_stream_struct *stp) +{ + if (stp->src_offset != (off_t)-1) { + rb_raise(rb_eArgError, "cannot specify src_offset"); + } + rb_rescue2(copy_stream_fallback_body, (VALUE)stp, + (VALUE (*) (ANYARGS))0, (VALUE)0, + rb_eEOFError, (VALUE)0); + return Qnil; +} + static VALUE copy_stream_body(VALUE arg) { @@ -6577,6 +6621,21 @@ copy_stream_body(VALUE arg) stp->th = GET_THREAD(); + stp->total = 0; + + if (stp->src == argf || + stp->dst == argf || + !(TYPE(stp->src) == T_FILE || + rb_respond_to(stp->src, rb_intern("to_io")) || + TYPE(stp->src) == T_STRING || + rb_respond_to(stp->src, rb_intern("to_path"))) || + !(TYPE(stp->dst) == T_FILE || + rb_respond_to(stp->dst, rb_intern("to_io")) || + TYPE(stp->dst) == T_STRING || + rb_respond_to(stp->dst, rb_intern("to_path")))) { + return copy_stream_fallback(stp); + } + src_io = rb_check_convert_type(stp->src, T_FILE, "IO", "to_io"); if (!NIL_P(src_io)) { GetOpenFile(src_io, src_fptr); @@ -6616,8 +6675,6 @@ copy_stream_body(VALUE arg) } stp->dst_fd = dst_fd; - stp->total = 0; - if (src_fptr && dst_fptr && src_fptr->rbuf_len && dst_fptr->wbuf_len) { long len = src_fptr->rbuf_len; VALUE str; @@ -6708,6 +6765,9 @@ rb_io_s_copy_stream(int argc, VALUE *argv, VALUE io) rb_scan_args(argc, argv, "22", &src, &dst, &length, &src_offset); + st.src = src; + st.dst = dst; + if (NIL_P(length)) st.copy_length = (off_t)-1; else @@ -6718,9 +6778,6 @@ rb_io_s_copy_stream(int argc, VALUE *argv, VALUE io) else st.src_offset = NUM2OFFT(src_offset); - st.src = src; - st.dst = dst; - rb_ensure(copy_stream_body, (VALUE)&st, copy_stream_finalize, (VALUE)&st); return OFFT2NUM(st.total); @@ -7344,6 +7401,7 @@ Init_IO(void) id_getc = rb_intern("getc"); id_flush = rb_intern("flush"); id_encode = rb_intern("encode"); + id_readpartial = rb_intern("readpartial"); rb_define_global_function("syscall", rb_f_syscall, -1); diff --git a/test/ruby/test_io.rb b/test/ruby/test_io.rb index 0cb8a775e2..d2292446fd 100644 --- a/test/ruby/test_io.rb +++ b/test/ruby/test_io.rb @@ -2,6 +2,7 @@ require 'test/unit' require 'tmpdir' require 'io/nonblock' require 'socket' +require 'stringio' class TestIO < Test::Unit::TestCase def test_gets_rs @@ -393,8 +394,33 @@ class TestIO < Test::Unit::TestCase result = t.value assert_equal(megacontent, result) } - - } end + + def test_copy_stream_strio + src = StringIO.new("abcd") + dst = StringIO.new + ret = IO.copy_stream(src, dst) + assert_equal(4, ret) + assert_equal("abcd", dst.string) + assert_equal(4, src.pos) + end + + def test_copy_stream_strio_len + src = StringIO.new("abcd") + dst = StringIO.new + ret = IO.copy_stream(src, dst, 3) + assert_equal(3, ret) + assert_equal("abc", dst.string) + assert_equal(3, src.pos) + end + + def test_copy_stream_strio_off + src = StringIO.new("abcd") + dst = StringIO.new + assert_raise(ArgumentError) { + IO.copy_stream(src, dst, 3, 1) + } + end + end