From b193041b992e5ce0ae1a07735fbdc53a739b5434 Mon Sep 17 00:00:00 2001 From: Jeremy Evans Date: Wed, 25 Sep 2019 17:57:00 -0700 Subject: [PATCH] Fix keyword argument separation issues in Fiber#resume --- cont.c | 41 +++++++++++++++------ include/ruby/intern.h | 2 + test/ruby/test_keyword.rb | 77 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 11 deletions(-) diff --git a/cont.c b/cont.c index d48f845369..d19e45aa8b 100644 --- a/cont.c +++ b/cont.c @@ -179,6 +179,7 @@ struct fiber_pool { typedef struct rb_context_struct { enum context_type type; int argc; + int kw_splat; VALUE self; VALUE value; @@ -1777,6 +1778,9 @@ rb_fiber_new(rb_block_call_func_t func, VALUE obj) static void rb_fiber_terminate(rb_fiber_t *fiber, int need_interrupt); +#define PASS_KW_SPLAT (rb_empty_keyword_given_p() ? RB_PASS_EMPTY_KEYWORDS : rb_keyword_given_p()) +extern VALUE rb_adjust_argv_kw_splat(int *argc, const VALUE **argv, int *kw_splat); + void rb_fiber_start(void) { @@ -1794,6 +1798,7 @@ rb_fiber_start(void) rb_context_t *cont = &VAR_FROM_MEMORY(fiber)->cont; int argc; const VALUE *argv, args = cont->value; + int kw_splat = cont->kw_splat; GetProcPtr(fiber->first_proc, proc); argv = (argc = cont->argc) > 1 ? RARRAY_CONST_PTR(args) : &args; cont->value = Qnil; @@ -1802,7 +1807,8 @@ rb_fiber_start(void) th->ec->root_svar = Qfalse; EXEC_EVENT_HOOK(th->ec, RUBY_EVENT_FIBER_SWITCH, th->self, 0, 0, 0, Qnil); - cont->value = rb_vm_invoke_proc(th->ec, proc, argc, argv, VM_NO_KEYWORDS, VM_BLOCK_HANDLER_NONE); + rb_adjust_argv_kw_splat(&argc, &argv, &kw_splat); + cont->value = rb_vm_invoke_proc(th->ec, proc, argc, argv, kw_splat, VM_BLOCK_HANDLER_NONE); } EC_POP_TAG(); @@ -1965,7 +1971,7 @@ fiber_store(rb_fiber_t *next_fiber, rb_thread_t *th) } static inline VALUE -fiber_switch(rb_fiber_t *fiber, int argc, const VALUE *argv, int is_resume) +fiber_switch(rb_fiber_t *fiber, int argc, const VALUE *argv, int is_resume, int kw_splat) { VALUE value; rb_context_t *cont = &fiber->cont; @@ -2017,6 +2023,7 @@ fiber_switch(rb_fiber_t *fiber, int argc, const VALUE *argv, int is_resume) VM_ASSERT(FIBER_RUNNABLE_P(fiber)); cont->argc = argc; + cont->kw_splat = kw_splat; cont->value = make_passing_arg(argc, argv); value = fiber_store(fiber, th); @@ -2035,7 +2042,7 @@ fiber_switch(rb_fiber_t *fiber, int argc, const VALUE *argv, int is_resume) VALUE rb_fiber_transfer(VALUE fiber_value, int argc, const VALUE *argv) { - return fiber_switch(fiber_ptr(fiber_value), argc, argv, 0); + return fiber_switch(fiber_ptr(fiber_value), argc, argv, 0, RB_NO_KEYWORDS); } void @@ -2060,11 +2067,11 @@ rb_fiber_terminate(rb_fiber_t *fiber, int need_interrupt) next_fiber = return_fiber(); if (need_interrupt) RUBY_VM_SET_INTERRUPT(&next_fiber->cont.saved_ec); - fiber_switch(next_fiber, 1, &value, 0); + fiber_switch(next_fiber, 1, &value, 0, RB_NO_KEYWORDS); } VALUE -rb_fiber_resume(VALUE fiber_value, int argc, const VALUE *argv) +rb_fiber_resume_kw(VALUE fiber_value, int argc, const VALUE *argv, int kw_splat) { rb_fiber_t *fiber = fiber_ptr(fiber_value); @@ -2080,13 +2087,25 @@ rb_fiber_resume(VALUE fiber_value, int argc, const VALUE *argv) rb_raise(rb_eFiberError, "cannot resume transferred Fiber"); } - return fiber_switch(fiber, argc, argv, 1); + return fiber_switch(fiber, argc, argv, 1, kw_splat); +} + +VALUE +rb_fiber_resume(VALUE fiber_value, int argc, const VALUE *argv) +{ + return rb_fiber_resume_kw(fiber_value, argc, argv, RB_NO_KEYWORDS); +} + +VALUE +rb_fiber_yield_kw(int argc, const VALUE *argv, int kw_splat) +{ + return fiber_switch(return_fiber(), argc, argv, 0, kw_splat); } VALUE rb_fiber_yield(int argc, const VALUE *argv) { - return fiber_switch(return_fiber(), argc, argv, 0); + return fiber_switch(return_fiber(), argc, argv, 0, RB_NO_KEYWORDS); } void @@ -2130,7 +2149,7 @@ rb_fiber_alive_p(VALUE fiber_value) static VALUE rb_fiber_m_resume(int argc, VALUE *argv, VALUE fiber) { - return rb_fiber_resume(fiber, argc, argv); + return rb_fiber_resume_kw(fiber, argc, argv, PASS_KW_SPLAT); } /* @@ -2156,7 +2175,7 @@ static VALUE rb_fiber_raise(int argc, VALUE *argv, VALUE fiber) { VALUE exc = rb_make_exception(argc, argv); - return rb_fiber_resume(fiber, -1, &exc); + return rb_fiber_resume_kw(fiber, -1, &exc, RB_NO_KEYWORDS); } /* @@ -2209,7 +2228,7 @@ rb_fiber_m_transfer(int argc, VALUE *argv, VALUE fiber_value) { rb_fiber_t *fiber = fiber_ptr(fiber_value); fiber->transferred = 1; - return fiber_switch(fiber, argc, argv, 0); + return fiber_switch(fiber, argc, argv, 0, PASS_KW_SPLAT); } /* @@ -2225,7 +2244,7 @@ rb_fiber_m_transfer(int argc, VALUE *argv, VALUE fiber_value) static VALUE rb_fiber_s_yield(int argc, VALUE *argv, VALUE klass) { - return rb_fiber_yield(argc, argv); + return rb_fiber_yield_kw(argc, argv, PASS_KW_SPLAT); } /* diff --git a/include/ruby/intern.h b/include/ruby/intern.h index a44d7fd2a0..4af254d920 100644 --- a/include/ruby/intern.h +++ b/include/ruby/intern.h @@ -240,7 +240,9 @@ NORETURN(void rb_cmperr(VALUE, VALUE)); /* cont.c */ VALUE rb_fiber_new(rb_block_call_func_t, VALUE); VALUE rb_fiber_resume(VALUE fib, int argc, const VALUE *argv); +VALUE rb_fiber_resume_kw(VALUE fib, int argc, const VALUE *argv, int kw_splat); VALUE rb_fiber_yield(int argc, const VALUE *argv); +VALUE rb_fiber_yield_kw(int argc, const VALUE *argv, int kw_splat); VALUE rb_fiber_current(void); VALUE rb_fiber_alive_p(VALUE); /* enum.c */ diff --git a/test/ruby/test_keyword.rb b/test/ruby/test_keyword.rb index dbced97b7e..4bda9d5fb8 100644 --- a/test/ruby/test_keyword.rb +++ b/test/ruby/test_keyword.rb @@ -764,6 +764,83 @@ class TestKeywordArguments < Test::Unit::TestCase Thread.report_on_exception = true end + def test_Fiber_resume_kwsplat + kw = {} + h = {:a=>1} + h2 = {'a'=>1} + h3 = {'a'=>1, :a=>1} + + t = Fiber + f = -> { true } + assert_equal(true, t.new(&f).resume(**{})) + assert_equal(true, t.new(&f).resume(**kw)) + assert_raise(ArgumentError) { t.new(&f).resume(**h) } + assert_raise(ArgumentError) { t.new(&f).resume(a: 1) } + assert_raise(ArgumentError) { t.new(&f).resume(**h2) } + assert_raise(ArgumentError) { t.new(&f).resume(**h3) } + + f = ->(a) { a } + assert_warn(/The keyword argument is passed as the last hash parameter/m) do + assert_equal(kw, t.new(&f).resume(**{})) + end + assert_warn(/The keyword argument is passed as the last hash parameter/m) do + assert_equal(kw, t.new(&f).resume(**kw)) + end + assert_equal(h, t.new(&f).resume(**h)) + assert_equal(h, t.new(&f).resume(a: 1)) + assert_equal(h2, t.new(&f).resume(**h2)) + assert_equal(h3, t.new(&f).resume(**h3)) + assert_equal(h3, t.new(&f).resume(a: 1, **h2)) + + f = ->(**x) { x } + assert_equal(kw, t.new(&f).resume(**{})) + assert_equal(kw, t.new(&f).resume(**kw)) + assert_equal(h, t.new(&f).resume(**h)) + assert_equal(h, t.new(&f).resume(a: 1)) + assert_equal(h2, t.new(&f).resume(**h2)) + assert_equal(h3, t.new(&f).resume(**h3)) + assert_equal(h3, t.new(&f).resume(a: 1, **h2)) + assert_warn(/The last argument is used as the keyword parameter.*for method/m) do + assert_equal(h, t.new(&f).resume(h)) + end + assert_raise(ArgumentError) { t.new(&f).resume(h2) } + assert_warn(/The last argument is split into positional and keyword parameters.*for method/m) do + assert_raise(ArgumentError) { t.new(&f).resume(h3) } + end + + f = ->(a, **x) { [a,x] } + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([{}, {}], t.new(&f).resume(**{})) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([{}, {}], t.new(&f).resume(**kw)) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([h, {}], t.new(&f).resume(**h)) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([h, {}], t.new(&f).resume(a: 1)) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([h2, {}], t.new(&f).resume(**h2)) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([h3, {}], t.new(&f).resume(**h3)) + end + assert_warn(/The keyword argument is passed as the last hash parameter/) do + assert_equal([h3, {}], t.new(&f).resume(a: 1, **h2)) + end + + f = ->(a=1, **x) { [a, x] } + assert_equal([1, kw], t.new(&f).resume(**{})) + assert_equal([1, kw], t.new(&f).resume(**kw)) + assert_equal([1, h], t.new(&f).resume(**h)) + assert_equal([1, h], t.new(&f).resume(a: 1)) + assert_equal([1, h2], t.new(&f).resume(**h2)) + assert_equal([1, h3], t.new(&f).resume(**h3)) + assert_equal([1, h3], t.new(&f).resume(a: 1, **h2)) + end + def test_Class_new_kwsplat_call kw = {} h = {:a=>1}