1
0
Fork 0
mirror of https://github.com/rails/rails.git synced 2022-11-09 12:12:34 -05:00

Merge pull request #23004 from matthewd/default-scope-sti

Make default scopes + STI happy again
This commit is contained in:
Matthew Draper 2016-01-12 13:59:46 +10:30
commit d16645a37a
7 changed files with 61 additions and 19 deletions

View file

@ -275,7 +275,7 @@ module ActiveRecord
def relation # :nodoc:
relation = Relation.create(self, arel_table, predicate_builder)
if finder_needs_type_condition?
if finder_needs_type_condition? && !ignore_default_scope?
relation.where(type_condition).create_with(inheritance_column.to_sym => sti_name)
else
relation

View file

@ -11,11 +11,11 @@ module ActiveRecord
module ClassMethods
def current_scope #:nodoc:
ScopeRegistry.value_for(:current_scope, self.to_s)
ScopeRegistry.value_for(:current_scope, self)
end
def current_scope=(scope) #:nodoc:
ScopeRegistry.set_value_for(:current_scope, self.to_s, scope)
ScopeRegistry.set_value_for(:current_scope, self, scope)
end
# Collects attributes from scopes that should be applied when creating
@ -53,18 +53,18 @@ module ActiveRecord
# following code:
#
# registry = ActiveRecord::Scoping::ScopeRegistry
# registry.set_value_for(:current_scope, "Board", some_new_scope)
# registry.set_value_for(:current_scope, Board, some_new_scope)
#
# Now when you run:
#
# registry.value_for(:current_scope, "Board")
# registry.value_for(:current_scope, Board)
#
# You will obtain whatever was defined in +some_new_scope+. The #value_for
# and #set_value_for methods are delegated to the current ScopeRegistry
# object, so the above example code can also be called as:
#
# ActiveRecord::Scoping::ScopeRegistry.set_value_for(:current_scope,
# "Board", some_new_scope)
# Board, some_new_scope)
class ScopeRegistry # :nodoc:
extend ActiveSupport::PerThreadRegistry
@ -74,16 +74,22 @@ module ActiveRecord
@registry = Hash.new { |hash, key| hash[key] = {} }
end
# Obtains the value for a given +scope_name+ and +variable_name+.
def value_for(scope_type, variable_name)
# Obtains the value for a given +scope_type+ and +model+.
def value_for(scope_type, model)
raise_invalid_scope_type!(scope_type)
@registry[scope_type][variable_name]
klass = model
base = model.base_class
while klass <= base
value = @registry[scope_type][klass.name]
return value if value
klass = klass.superclass
end
end
# Sets the +value+ for a given +scope_type+ and +variable_name+.
def set_value_for(scope_type, variable_name, value)
# Sets the +value+ for a given +scope_type+ and +model+.
def set_value_for(scope_type, model, value)
raise_invalid_scope_type!(scope_type)
@registry[scope_type][variable_name] = value
@registry[scope_type][model.name] = value
end
private

View file

@ -122,11 +122,11 @@ module ActiveRecord
end
def ignore_default_scope? # :nodoc:
ScopeRegistry.value_for(:ignore_default_scope, self)
ScopeRegistry.value_for(:ignore_default_scope, base_class)
end
def ignore_default_scope=(ignore) # :nodoc:
ScopeRegistry.set_value_for(:ignore_default_scope, self, ignore)
ScopeRegistry.set_value_for(:ignore_default_scope, base_class, ignore)
end
# The ignore_default_scope flag is used to prevent an infinite recursion

View file

@ -1275,9 +1275,10 @@ class BasicsTest < ActiveRecord::TestCase
UnloadablePost.send(:current_scope=, UnloadablePost.all)
UnloadablePost.unloadable
assert_not_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, "UnloadablePost")
klass = UnloadablePost
assert_not_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, klass)
ActiveSupport::Dependencies.remove_unloadable_constants!
assert_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, "UnloadablePost")
assert_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, klass)
ensure
Object.class_eval{ remove_const :UnloadablePost } if defined?(UnloadablePost)
end

View file

@ -459,4 +459,18 @@ class DefaultScopingTest < ActiveRecord::TestCase
scope = Bus.all
assert_equal scope.where_clause.ast.children.length, 1
end
def test_sti_conditions_are_not_carried_in_default_scope
ConditionalStiPost.create! body: ''
SubConditionalStiPost.create! body: ''
SubConditionalStiPost.create! title: 'Hello world', body: ''
assert_equal 2, ConditionalStiPost.count
assert_equal 2, ConditionalStiPost.all.to_a.size
assert_equal 3, ConditionalStiPost.unscope(where: :title).to_a.size
assert_equal 1, SubConditionalStiPost.count
assert_equal 1, SubConditionalStiPost.all.to_a.size
assert_equal 2, SubConditionalStiPost.unscope(where: :title).to_a.size
end
end

View file

@ -209,9 +209,23 @@ class RelationScopingTest < ActiveRecord::TestCase
assert_not_equal [], Developer.all
end
def test_current_scope_does_not_pollute_other_subclasses
Post.none.scoping do
assert StiPost.all.any?
def test_current_scope_does_not_pollute_sibling_subclasses
Comment.none.scoping do
assert_not SpecialComment.all.any?
assert_not VerySpecialComment.all.any?
assert_not SubSpecialComment.all.any?
end
SpecialComment.none.scoping do
assert Comment.all.any?
assert VerySpecialComment.all.any?
assert_not SubSpecialComment.all.any?
end
SubSpecialComment.none.scoping do
assert Comment.all.any?
assert VerySpecialComment.all.any?
assert SpecialComment.all.any?
end
end
end

View file

@ -263,3 +263,10 @@ end
class SerializedPost < ActiveRecord::Base
serialize :title
end
class ConditionalStiPost < Post
default_scope { where(title: 'Untitled') }
end
class SubConditionalStiPost < ConditionalStiPost
end