1
0
Fork 0
mirror of https://github.com/rails/rails.git synced 2022-11-09 12:12:34 -05:00

Move Arel attribute normalization into arel_table

In Active Record internal, `arel_table` is not directly used but
`arel_attribute` is used, since `arel_table` doesn't normalize an
attribute name as a string, and doesn't resolve attribute aliases.

For the above reason, `arel_attribute` should be used rather than
`arel_table`, but most people directly use `arel_table`, both
`arel_table` and `arel_attribute` are private API though.

Although I'd not recommend using private API, `arel_table` is actually
widely used, and it is also problematic for unscopeable queries and
hash-like relation merging friendly, as I explained at #39863.

To resolve the issue, this change moves Arel attribute normalization
(attribute name as a string, and attribute alias resolution) into
`arel_table`.
This commit is contained in:
Ryuta Kamizono 2020-07-19 20:00:42 +09:00
parent c418fb9fe9
commit 1ac40f16c5
21 changed files with 50 additions and 52 deletions

View file

@ -51,11 +51,11 @@ module ActiveRecord
@connection = connection @connection = connection
end end
def aliased_table_for(table_name, aliased_name, type_caster) def aliased_table_for(table_name, aliased_name, klass)
if aliases[table_name].zero? if aliases[table_name].zero?
# If it's zero, we can have our table_name # If it's zero, we can have our table_name
aliases[table_name] = 1 aliases[table_name] = 1
Arel::Table.new(table_name, type_caster: type_caster) Arel::Table.new(table_name, klass: klass)
else else
# Otherwise, we need to use an alias # Otherwise, we need to use an alias
aliased_name = @connection.table_alias_for(aliased_name) aliased_name = @connection.table_alias_for(aliased_name)
@ -68,7 +68,7 @@ module ActiveRecord
else else
aliased_name aliased_name
end end
Arel::Table.new(table_name, type_caster: type_caster).alias(table_alias) Arel::Table.new(table_name, klass: klass).alias(table_alias)
end end
end end

View file

@ -109,7 +109,7 @@ module ActiveRecord
aliased_table = tracker.aliased_table_for( aliased_table = tracker.aliased_table_for(
refl.table_name, refl.table_name,
refl.alias_candidate(name), refl.alias_candidate(name),
refl.klass.type_caster refl.klass
) )
chain << ReflectionProxy.new(refl, aliased_table) chain << ReflectionProxy.new(refl, aliased_table)
end end

View file

@ -176,7 +176,7 @@ module ActiveRecord
alias_tracker.aliased_table_for( alias_tracker.aliased_table_for(
reflection.table_name, reflection.table_name,
table_alias_for(reflection, parent, reflection != child.reflection), table_alias_for(reflection, parent, reflection != child.reflection),
reflection.klass.type_caster reflection.klass
) )
end.concat child.children.flat_map { |c| make_constraints(child, c, join_type) } end.concat child.children.flat_map { |c| make_constraints(child, c, join_type) }
end end

View file

@ -289,12 +289,10 @@ module ActiveRecord
# scope :published_and_commented, -> { published.and(arel_table[:comments_count].gt(0)) } # scope :published_and_commented, -> { published.and(arel_table[:comments_count].gt(0)) }
# end # end
def arel_table # :nodoc: def arel_table # :nodoc:
@arel_table ||= Arel::Table.new(table_name, type_caster: type_caster) @arel_table ||= Arel::Table.new(table_name, klass: self)
end end
def arel_attribute(name, table = arel_table) # :nodoc: def arel_attribute(name, table = arel_table) # :nodoc:
name = name.to_s
name = attribute_aliases[name] || name
table[name] table[name]
end end

View file

@ -260,7 +260,7 @@ module ActiveRecord
end end
def type_condition(table = arel_table) def type_condition(table = arel_table)
sti_column = arel_attribute(inheritance_column, table) sti_column = table[inheritance_column]
sti_names = ([self] + descendants).map(&:sti_name) sti_names = ([self] + descendants).map(&:sti_name)
predicate_builder.build(sti_column, sti_names) predicate_builder.build(sti_column, sti_names)

View file

@ -414,7 +414,7 @@ module ActiveRecord
def _substitute_values(values) def _substitute_values(values)
values.map do |name, value| values.map do |name, value|
attr = arel_attribute(name) attr = arel_table[name]
bind = predicate_builder.build_bind_attribute(name, value) bind = predicate_builder.build_bind_attribute(name, value)
[attr, bind] [attr, bind]
end end

View file

@ -1052,7 +1052,7 @@ module ActiveRecord
end end
def aliased_table def aliased_table
@aliased_table ||= Arel::Table.new(table_name, type_caster: klass.type_caster) @aliased_table ||= Arel::Table.new(table_name, klass: klass)
end end
def join_primary_key(klass = self.klass) def join_primary_key(klass = self.klass)

View file

@ -39,7 +39,7 @@ module ActiveRecord
end end
def arel_attribute(name) # :nodoc: def arel_attribute(name) # :nodoc:
klass.arel_attribute(name, table) table[name]
end end
def bind_attribute(name, value) # :nodoc: def bind_attribute(name, value) # :nodoc:
@ -48,7 +48,7 @@ module ActiveRecord
value = value.read_attribute(reflection.klass.primary_key) unless value.nil? value = value.read_attribute(reflection.klass.primary_key) unless value.nil?
end end
attr = arel_attribute(name) attr = table[name]
bind = predicate_builder.build_bind_attribute(attr.name, value) bind = predicate_builder.build_bind_attribute(attr.name, value)
yield attr, bind yield attr, bind
end end
@ -352,7 +352,7 @@ module ActiveRecord
else else
collection = eager_loading? ? apply_join_dependency : self collection = eager_loading? ? apply_join_dependency : self
column = connection.visitor.compile(arel_attribute(timestamp_column)) column = connection.visitor.compile(table[timestamp_column])
select_values = "COUNT(*) AS #{connection.quote_column_name("size")}, MAX(%s) AS timestamp" select_values = "COUNT(*) AS #{connection.quote_column_name("size")}, MAX(%s) AS timestamp"
if collection.has_limit_or_offset? if collection.has_limit_or_offset?
@ -447,7 +447,7 @@ module ActiveRecord
stmt = Arel::UpdateManager.new stmt = Arel::UpdateManager.new
stmt.table(arel.join_sources.empty? ? table : arel.source) stmt.table(arel.join_sources.empty? ? table : arel.source)
stmt.key = arel_attribute(primary_key) stmt.key = table[primary_key]
stmt.take(arel.limit) stmt.take(arel.limit)
stmt.offset(arel.offset) stmt.offset(arel.offset)
stmt.order(*arel.orders) stmt.order(*arel.orders)
@ -457,7 +457,7 @@ module ActiveRecord
if klass.locking_enabled? && if klass.locking_enabled? &&
!updates.key?(klass.locking_column) && !updates.key?(klass.locking_column) &&
!updates.key?(klass.locking_column.to_sym) !updates.key?(klass.locking_column.to_sym)
attr = arel_attribute(klass.locking_column) attr = table[klass.locking_column]
updates[attr.name] = _increment_attribute(attr) updates[attr.name] = _increment_attribute(attr)
end end
stmt.set _substitute_values(updates) stmt.set _substitute_values(updates)
@ -493,7 +493,7 @@ module ActiveRecord
updates = {} updates = {}
counters.each do |counter_name, value| counters.each do |counter_name, value|
attr = arel_attribute(counter_name) attr = table[counter_name]
updates[attr.name] = _increment_attribute(attr, value) updates[attr.name] = _increment_attribute(attr, value)
end end
@ -589,7 +589,7 @@ module ActiveRecord
stmt = Arel::DeleteManager.new stmt = Arel::DeleteManager.new
stmt.from(arel.join_sources.empty? ? table : arel.source) stmt.from(arel.join_sources.empty? ? table : arel.source)
stmt.key = arel_attribute(primary_key) stmt.key = table[primary_key]
stmt.take(arel.limit) stmt.take(arel.limit)
stmt.offset(arel.offset) stmt.offset(arel.offset)
stmt.order(*arel.orders) stmt.order(*arel.orders)
@ -813,7 +813,7 @@ module ActiveRecord
def _substitute_values(values) def _substitute_values(values)
values.map do |name, value| values.map do |name, value|
attr = arel_attribute(name) attr = table[name]
unless Arel.arel_node?(value) unless Arel.arel_node?(value)
type = klass.type_for_attribute(attr.name) type = klass.type_for_attribute(attr.name)
value = predicate_builder.build_bind_attribute(attr.name, type.cast(value)) value = predicate_builder.build_bind_attribute(attr.name, type.cast(value))

View file

@ -280,7 +280,7 @@ module ActiveRecord
end end
def batch_order(order) def batch_order(order)
arel_attribute(primary_key).public_send(order) table[primary_key].public_send(order)
end end
def act_on_ignored_order(error_on_ignore) def act_on_ignored_order(error_on_ignore)

View file

@ -410,7 +410,7 @@ module ActiveRecord
def limited_ids_for(relation) def limited_ids_for(relation)
values = @klass.connection.columns_for_distinct( values = @klass.connection.columns_for_distinct(
connection.visitor.compile(arel_attribute(primary_key)), connection.visitor.compile(table[primary_key]),
relation.order_values relation.order_values
) )
@ -562,9 +562,9 @@ module ActiveRecord
def ordered_relation def ordered_relation
if order_values.empty? && (implicit_order_column || primary_key) if order_values.empty? && (implicit_order_column || primary_key)
if implicit_order_column && primary_key && implicit_order_column != primary_key if implicit_order_column && primary_key && implicit_order_column != primary_key
order(arel_attribute(implicit_order_column).asc, arel_attribute(primary_key).asc) order(table[implicit_order_column].asc, table[primary_key].asc)
else else
order(arel_attribute(implicit_order_column || primary_key).asc) order(table[implicit_order_column || primary_key].asc)
end end
else else
self self

View file

@ -9,7 +9,7 @@ module ActiveRecord
end end
if value.select_values.empty? if value.select_values.empty?
value = value.select(value.arel_attribute(value.klass.primary_key)) value = value.select(value.table[value.klass.primary_key])
end end
attribute.in(value.arel) attribute.in(value.arel)

View file

@ -1304,7 +1304,7 @@ module ActiveRecord
if select_values.any? if select_values.any?
arel.project(*arel_columns(select_values.uniq)) arel.project(*arel_columns(select_values.uniq))
elsif klass.ignored_columns.any? elsif klass.ignored_columns.any?
arel.project(*klass.column_names.map { |field| arel_attribute(field) }) arel.project(*klass.column_names.map { |field| table[field] })
else else
arel.project(table[Arel.star]) arel.project(table[Arel.star])
end end
@ -1332,7 +1332,7 @@ module ActiveRecord
from = from_clause.name || from_clause.value from = from_clause.name || from_clause.value
if klass.columns_hash.key?(field) && (!from || table_name_matches?(from)) if klass.columns_hash.key?(field) && (!from || table_name_matches?(from))
arel_attribute(field) table[field]
elsif field.match?(/\A\w+\.\w+\z/) elsif field.match?(/\A\w+\.\w+\z/)
table, column = field.split(".") table, column = field.split(".")
predicate_builder.resolve_arel_attribute(table, column) do predicate_builder.resolve_arel_attribute(table, column) do
@ -1351,7 +1351,7 @@ module ActiveRecord
def reverse_sql_order(order_query) def reverse_sql_order(order_query)
if order_query.empty? if order_query.empty?
return [arel_attribute(primary_key).desc] if primary_key return [table[primary_key].desc] if primary_key
raise IrreversibleOrderError, raise IrreversibleOrderError,
"Relation has no current order and table has no primary key to be used as default order" "Relation has no current order and table has no primary key to be used as default order"
end end
@ -1457,7 +1457,7 @@ module ActiveRecord
def order_column(field) def order_column(field)
arel_column(field) do |attr_name| arel_column(field) do |attr_name|
if attr_name == "count" && !group_values.empty? if attr_name == "count" && !group_values.empty?
arel_attribute(attr_name) table[attr_name]
else else
Arel.sql(connection.quote_table_name(attr_name)) Arel.sql(connection.quote_table_name(attr_name))
end end

View file

@ -11,12 +11,8 @@ module ActiveRecord
end end
def arel_attribute(column_name) def arel_attribute(column_name)
if klass
klass.arel_attribute(column_name, arel_table)
else
arel_table[column_name] arel_table[column_name]
end end
end
def type(column_name) def type(column_name)
arel_table.type_for_attribute(column_name) arel_table.type_for_attribute(column_name)

View file

@ -8,7 +8,7 @@ module Arel # :nodoc: all
alias :table_alias :name alias :table_alias :name
def [](name) def [](name)
Attribute.new(self, name) relation.is_a?(Table) ? relation[name, self] : Attribute.new(self, name)
end end
def table_name def table_name

View file

@ -14,8 +14,9 @@ module Arel # :nodoc: all
# TableAlias and Table both have a #table_name which is the name of the underlying table # TableAlias and Table both have a #table_name which is the name of the underlying table
alias :table_name :name alias :table_name :name
def initialize(name, as: nil, type_caster: nil) def initialize(name, as: nil, klass: nil, type_caster: klass&.type_caster)
@name = name.to_s @name = name.to_s
@klass = klass
@type_caster = type_caster @type_caster = type_caster
# Sometime AR sends an :as parameter to table, to let the table know # Sometime AR sends an :as parameter to table, to let the table know
@ -79,8 +80,10 @@ module Arel # :nodoc: all
from.having expr from.having expr
end end
def [](name) def [](name, table = self)
::Arel::Attribute.new self, name name = name.to_s if name.is_a?(Symbol)
name = @klass.attribute_aliases[name] || name if @klass
Attribute.new(table, name)
end end
def hash def hash

View file

@ -8,19 +8,19 @@ class PostgresqlCaseInsensitiveTest < ActiveRecord::PostgreSQLTestCase
def test_case_insensitiveness def test_case_insensitiveness
connection = ActiveRecord::Base.connection connection = ActiveRecord::Base.connection
attr = Default.arel_attribute(:char1) attr = Default.arel_table[:char1]
comparison = connection.case_insensitive_comparison(attr, nil) comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql) assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:char2) attr = Default.arel_table[:char2]
comparison = connection.case_insensitive_comparison(attr, nil) comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql) assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:char3) attr = Default.arel_table[:char3]
comparison = connection.case_insensitive_comparison(attr, nil) comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql) assert_match(/lower/i, comparison.to_sql)
attr = Default.arel_attribute(:multiline_default) attr = Default.arel_table[:multiline_default]
comparison = connection.case_insensitive_comparison(attr, nil) comparison = connection.case_insensitive_comparison(attr, nil)
assert_match(/lower/i, comparison.to_sql) assert_match(/lower/i, comparison.to_sql)
end end

View file

@ -188,7 +188,7 @@ module Arel
describe "when given a Symbol" do describe "when given a Symbol" do
it "manufactures an attribute if the symbol names an attribute within the relation" do it "manufactures an attribute if the symbol names an attribute within the relation" do
column = @relation[:id] column = @relation[:id]
_(column.name).must_equal :id _(column.name).must_equal "id"
end end
end end
end end

View file

@ -71,7 +71,7 @@ class InnerJoinAssociationTest < ActiveRecord::TestCase
def test_deduplicate_joins def test_deduplicate_joins
posts = Post.arel_table posts = Post.arel_table
constraint = posts[:author_id].eq(Author.arel_attribute(:id)) constraint = posts[:author_id].eq(Author.arel_table[:id])
authors = Author.joins(posts.create_join(posts, posts.create_on(constraint))) authors = Author.joins(posts.create_join(posts, posts.create_on(constraint)))
authors = authors.joins(:author_address).merge(authors.where("posts.type": "SpecialPost")) authors = authors.joins(:author_address).merge(authors.where("posts.type": "SpecialPost"))

View file

@ -78,6 +78,11 @@ class BasicsTest < ActiveRecord::TestCase
assert_equal "Post::GeneratedRelationMethods", mod.inspect assert_equal "Post::GeneratedRelationMethods", mod.inspect
end end
def test_arel_attribute_normalization
assert_equal Post.arel_table["body"], Post.arel_table[:body]
assert_equal Post.arel_table["body"], Post.arel_table[:text]
end
def test_incomplete_schema_loading def test_incomplete_schema_loading
topic = Topic.first topic = Topic.first
payload = { foo: 42 } payload = { foo: 42 }

View file

@ -1261,7 +1261,7 @@ class RelationTest < ActiveRecord::TestCase
assert_predicate same_parrot, :persisted? assert_predicate same_parrot, :persisted?
assert_equal parrot, same_parrot assert_equal parrot, same_parrot
canary = Bird.where(Bird.arel_attribute(:color).is_distinct_from("green")).first_or_create(name: "canary") canary = Bird.where(Bird.arel_table[:color].is_distinct_from("green")).first_or_create(name: "canary")
assert_equal "canary", canary.name assert_equal "canary", canary.name
assert_nil canary.color assert_nil canary.color
end end
@ -1385,7 +1385,7 @@ class RelationTest < ActiveRecord::TestCase
assert_equal "parrot", parrot.name assert_equal "parrot", parrot.name
assert_equal "green", parrot.color assert_equal "green", parrot.color
canary = Bird.where(Bird.arel_attribute(:color).is_distinct_from("green")).first_or_initialize(name: "canary") canary = Bird.where(Bird.arel_table[:color].is_distinct_from("green")).first_or_initialize(name: "canary")
assert_equal "canary", canary.name assert_equal "canary", canary.name
assert_nil canary.color assert_nil canary.color
end end
@ -1963,7 +1963,7 @@ class RelationTest < ActiveRecord::TestCase
assert_equal post, custom_post_relation.joins(:author).where!(title: post.title).take assert_equal post, custom_post_relation.joins(:author).where!(title: post.title).take
end end
test "arel_attribute respects a custom table" do test "arel_table respects a custom table" do
assert_equal [posts(:sti_comments)], custom_post_relation.ranked_by_comments.limit_by(1).to_a assert_equal [posts(:sti_comments)], custom_post_relation.ranked_by_comments.limit_by(1).to_a
end end
@ -2093,7 +2093,7 @@ class RelationTest < ActiveRecord::TestCase
end end
def test_unscope_with_arel_sql def test_unscope_with_arel_sql
posts = Post.where(Arel.sql("'Welcome to the weblog'").eq(Post.arel_attribute(:title))) posts = Post.where(Arel.sql("'Welcome to the weblog'").eq(Post.arel_table[:title]))
assert_equal 1, posts.count assert_equal 1, posts.count
assert_equal Post.count, posts.unscope(where: :title).count assert_equal Post.count, posts.unscope(where: :title).count

View file

@ -28,7 +28,7 @@ class Post < ActiveRecord::Base
scope :containing_the_letter_a, -> { where("body LIKE '%a%'") } scope :containing_the_letter_a, -> { where("body LIKE '%a%'") }
scope :titled_with_an_apostrophe, -> { where("title LIKE '%''%'") } scope :titled_with_an_apostrophe, -> { where("title LIKE '%''%'") }
scope :ranked_by_comments, -> { order(arel_attribute(:comments_count).desc) } scope :ranked_by_comments, -> { order(table[:comments_count].desc) }
scope :limit_by, lambda { |l| limit(l) } scope :limit_by, lambda { |l| limit(l) }
scope :locked, -> { lock } scope :locked, -> { lock }
@ -339,10 +339,6 @@ class FakeKlass
sql sql
end end
def arel_attribute(name, table)
table[name]
end
def disallow_raw_sql!(*args) def disallow_raw_sql!(*args)
# noop # noop
end end