diff --git a/lib/pundit.rb b/lib/pundit.rb index efb49e7..29e84dc 100644 --- a/lib/pundit.rb +++ b/lib/pundit.rb @@ -63,12 +63,14 @@ module Pundit def policy_scope(scope) @_policy_scoped = true - Pundit.policy_scope!(pundit_user, scope) + @policy_scope or Pundit.policy_scope!(pundit_user, scope) end + attr_writer :policy_scope def policy(record) - Pundit.policy!(pundit_user, record) + @policy or Pundit.policy!(pundit_user, record) end + attr_writer :policy def pundit_user current_user diff --git a/spec/pundit_spec.rb b/spec/pundit_spec.rb index 9dc7a89..3bcfe6a 100644 --- a/spec/pundit_spec.rb +++ b/spec/pundit_spec.rb @@ -242,6 +242,13 @@ describe Pundit do it "throws an exception if the given policy can't be found" do expect { controller.policy(article) }.to raise_error(Pundit::NotDefinedError) end + + it "allows policy to be injected" do + new_policy = OpenStruct.new + controller.policy = new_policy + + controller.policy(post).should == new_policy + end end describe ".policy_scope" do @@ -252,5 +259,12 @@ describe Pundit do it "throws an exception if the given policy can't be found" do expect { controller.policy_scope(Article) }.to raise_error(Pundit::NotDefinedError) end + + it "allows policy_scope to be injected" do + new_scope = OpenStruct.new + controller.policy_scope = new_scope + + controller.policy_scope(post).should == new_scope + end end end