From 9641d0b2ef1e88b2fd61ec5771ef3d62a82f7c8c Mon Sep 17 00:00:00 2001 From: Ryuta Kamizono Date: Fri, 29 May 2020 03:07:26 +0900 Subject: [PATCH] Support type casting for grandchild's attributes Related to #39292. Fixes #39460. --- .../active_record/relation/calculations.rb | 21 +++------ .../relation/predicate_builder.rb | 13 +++--- .../active_record/relation/query_methods.rb | 34 +++++++++++++- .../relation/where_clause_factory.rb | 4 +- .../lib/active_record/table_metadata.rb | 44 +++++++++---------- activerecord/test/cases/calculations_test.rb | 20 +++++---- activerecord/test/cases/inheritance_test.rb | 4 +- activerecord/test/models/company.rb | 4 +- 8 files changed, 82 insertions(+), 62 deletions(-) diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index ebf9b19e32..bb0435524b 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -309,7 +309,7 @@ module ActiveRecord type_cast_calculated_value(result.cast_values.first, operation) do |value| type = column.try(:type_caster) || - lookup_cast_type_from_join_dependencies(column_name.to_s, build_join_dependencies) || Type.default_value + lookup_cast_type_from_join_dependencies(column_name.to_s) || Type.default_value type.deserialize(value) end end @@ -388,7 +388,7 @@ module ActiveRecord result[key] = type_cast_calculated_value(row[column_alias], operation) do |value| type ||= column.try(:type_caster) || - lookup_cast_type_from_join_dependencies(column_name.to_s, build_join_dependencies) || Type.default_value + lookup_cast_type_from_join_dependencies(column_name.to_s) || Type.default_value type.deserialize(value) end end @@ -416,19 +416,10 @@ module ActiveRecord @klass.type_for_attribute(field_name, &block) end - def build_join_dependencies - join_dependencies = [] - join_dependencies.unshift construct_join_dependency( - select_association_list(joins_values + left_outer_joins_values, join_dependencies), nil - ) - end - - def lookup_cast_type_from_join_dependencies(name, join_dependencies) - join_dependencies.each do |join_dependency| - join_dependency.each do |join| - type = join.base_klass.attribute_types.fetch(name, nil) - return type if type - end + def lookup_cast_type_from_join_dependencies(name, join_dependencies = build_join_dependencies) + each_join_dependencies(join_dependencies) do |join| + type = join.base_klass.attribute_types.fetch(name, nil) + return type if type end nil end diff --git a/activerecord/lib/active_record/relation/predicate_builder.rb b/activerecord/lib/active_record/relation/predicate_builder.rb index 657634bc33..95b2634429 100644 --- a/activerecord/lib/active_record/relation/predicate_builder.rb +++ b/activerecord/lib/active_record/relation/predicate_builder.rb @@ -14,11 +14,11 @@ module ActiveRecord register_handler(Set, ArrayHandler.new(self)) end - def build_from_hash(attributes) + def build_from_hash(attributes, &block) attributes = attributes.stringify_keys attributes = convert_dot_notation_to_hash(attributes) - expand_from_hash(attributes) + expand_from_hash(attributes, &block) end def self.references(attributes) @@ -61,17 +61,18 @@ module ActiveRecord Arel::Nodes::BindParam.new(attr) end - def resolve_arel_attribute(table_name, column_name) - table.associated_table(table_name).arel_attribute(column_name) + def resolve_arel_attribute(table_name, column_name, &block) + table.associated_table(table_name, &block).arel_attribute(column_name) end protected - def expand_from_hash(attributes) + def expand_from_hash(attributes, &block) return ["1=0"] if attributes.empty? attributes.flat_map do |key, value| if value.is_a?(Hash) && !table.has_column?(key) - table.associated_predicate_builder(key).expand_from_hash(value) + table.associated_table(key, &block) + .predicate_builder.expand_from_hash(value.stringify_keys) elsif table.associated_with?(key) # Find the foreign key when using queries such as: # Post.where(author: author) diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 2907e7a16f..7e18238f4c 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -1077,11 +1077,39 @@ module ActiveRecord def build_where_clause(opts, rest = []) # :nodoc: opts = sanitize_forbidden_attributes(opts) self.references_values |= PredicateBuilder.references(opts) if Hash === opts - where_clause_factory.build(opts, rest) + where_clause_factory.build(opts, rest) do |table_name| + lookup_reflection_from_join_dependencies(table_name) + end end alias :build_having_clause :build_where_clause private + def lookup_reflection_from_join_dependencies(table_name) + each_join_dependencies do |join| + return join.reflection if table_name == join.table_name + end + nil + end + + def each_join_dependencies(join_dependencies = build_join_dependencies) + join_dependencies.each do |join_dependency| + join_dependency.each do |join| + yield join + end + end + end + + def build_join_dependencies + associations = joins_values | left_outer_joins_values + associations |= eager_load_values unless eager_load_values.empty? + associations |= includes_values unless includes_values.empty? + + join_dependencies = [] + join_dependencies.unshift construct_join_dependency( + select_association_list(associations, join_dependencies), nil + ) + end + def assert_mutability! raise ImmutableRelation if @loaded raise ImmutableRelation if defined?(@arel) && @arel @@ -1268,7 +1296,9 @@ module ActiveRecord arel_attribute(field) elsif field.match?(/\A\w+\.\w+\z/) table, column = field.split(".") - predicate_builder.resolve_arel_attribute(table, column) + predicate_builder.resolve_arel_attribute(table, column) do + lookup_reflection_from_join_dependencies(table) + end else yield field end diff --git a/activerecord/lib/active_record/relation/where_clause_factory.rb b/activerecord/lib/active_record/relation/where_clause_factory.rb index 28baf80afe..f46ab6de38 100644 --- a/activerecord/lib/active_record/relation/where_clause_factory.rb +++ b/activerecord/lib/active_record/relation/where_clause_factory.rb @@ -8,12 +8,12 @@ module ActiveRecord @predicate_builder = predicate_builder end - def build(opts, other) + def build(opts, other, &block) case opts when String, Array parts = [klass.sanitize_sql(other.empty? ? opts : ([opts] + other))] when Hash - parts = predicate_builder.build_from_hash(opts) + parts = predicate_builder.build_from_hash(opts, &block) when Arel::Nodes::Node parts = [opts] else diff --git a/activerecord/lib/active_record/table_metadata.rb b/activerecord/lib/active_record/table_metadata.rb index 5ae7ec1240..402a903374 100644 --- a/activerecord/lib/active_record/table_metadata.rb +++ b/activerecord/lib/active_record/table_metadata.rb @@ -24,19 +24,23 @@ module ActiveRecord end def has_column?(column_name) - klass && klass.columns_hash.key?(column_name.to_s) + klass&.columns_hash.key?(column_name) end - def associated_with?(association_name) - klass && klass._reflect_on_association(association_name) + def associated_with?(table_name) + klass&._reflect_on_association(table_name) || klass&._reflect_on_association(table_name.singularize) end def associated_table(table_name) - reflection = klass._reflect_on_association(table_name) || klass._reflect_on_association(table_name.to_s.singularize) + reflection = klass._reflect_on_association(table_name) || klass._reflect_on_association(table_name.singularize) if !reflection && table_name == arel_table.name - self - elsif reflection && !reflection.polymorphic? + return self + end + + reflection ||= yield table_name if block_given? + + if reflection && !reflection.polymorphic? association_klass = reflection.klass arel_table = association_klass.arel_table.alias(table_name) TableMetadata.new(association_klass, arel_table, reflection) @@ -47,32 +51,24 @@ module ActiveRecord end end - def associated_predicate_builder(table_name) - associated_table(table_name).predicate_builder - end - def polymorphic_association? reflection&.polymorphic? end - def aggregated_with?(aggregation_name) - klass && reflect_on_aggregation(aggregation_name) - end - def reflect_on_aggregation(aggregation_name) - klass.reflect_on_aggregation(aggregation_name) + klass&.reflect_on_aggregation(aggregation_name) end + alias :aggregated_with? :reflect_on_aggregation - protected - def predicate_builder - if klass - predicate_builder = klass.predicate_builder.dup - predicate_builder.instance_variable_set(:@table, self) - predicate_builder - else - PredicateBuilder.new(self) - end + def predicate_builder + if klass + predicate_builder = klass.predicate_builder.dup + predicate_builder.instance_variable_set(:@table, self) + predicate_builder + else + PredicateBuilder.new(self) end + end private attr_reader :klass, :types, :arel_table, :reflection diff --git a/activerecord/test/cases/calculations_test.rb b/activerecord/test/cases/calculations_test.rb index 34fab5a159..a3656c1713 100644 --- a/activerecord/test/cases/calculations_test.rb +++ b/activerecord/test/cases/calculations_test.rb @@ -23,7 +23,7 @@ require "models/rating" require "support/stubs/strong_parameters" class CalculationsTest < ActiveRecord::TestCase - fixtures :companies, :accounts, :authors, :topics, :speedometers, :minivans, :books, :posts, :comments + fixtures :companies, :accounts, :authors, :author_addresses, :topics, :speedometers, :minivans, :books, :posts, :comments def test_should_sum_field assert_equal 318, Account.sum(:credit_limit) @@ -751,33 +751,35 @@ class CalculationsTest < ActiveRecord::TestCase [Date.new(2004, 4, 15), "reading"], [Date.new(2004, 4, 15), "read"], ] - actual = - Author.joins(:topics, :books).order(:"books.last_read") - .where.not("books.last_read": nil) + actual = AuthorAddress.joins(author: [:topics, :books]).order(:"books.last_read") + .where("books.last_read": [:unread, :reading, :read]) .pluck(:"topics.last_read", :"books.last_read") assert_equal expected, actual end def test_pluck_type_cast_with_joins_without_table_name_qualified_column - assert_pluck_type_cast_without_table_name_qualified_column(Author.joins(:books)) + assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.joins(author: :books)) end def test_pluck_type_cast_with_left_joins_without_table_name_qualified_column - assert_pluck_type_cast_without_table_name_qualified_column(Author.left_joins(:books)) + assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.left_joins(author: :books)) end def test_pluck_type_cast_with_eager_load_without_table_name_qualified_column - assert_pluck_type_cast_without_table_name_qualified_column(Author.eager_load(:books)) + assert_pluck_type_cast_without_table_name_qualified_column(AuthorAddress.eager_load(author: :books)) end - def assert_pluck_type_cast_without_table_name_qualified_column(authors) + def assert_pluck_type_cast_without_table_name_qualified_column(author_addresses) expected = [ [nil, "unread"], ["ebook", "reading"], ["paperback", "read"], ] - actual = authors.order(:last_read).where.not("books.last_read": nil).pluck(:format, :last_read) + actual = author_addresses.order(:last_read) + .where("books.last_read": [:unread, :reading, :read]) + .pluck(:format, :last_read) + assert_equal expected, actual end private :assert_pluck_type_cast_without_table_name_qualified_column diff --git a/activerecord/test/cases/inheritance_test.rb b/activerecord/test/cases/inheritance_test.rb index 01e4878c3f..a2a7d0ae93 100644 --- a/activerecord/test/cases/inheritance_test.rb +++ b/activerecord/test/cases/inheritance_test.rb @@ -488,8 +488,8 @@ class InheritanceTest < ActiveRecord::TestCase end def test_scope_inherited_properly - assert_nothing_raised { Company.of_first_firm } - assert_nothing_raised { Client.of_first_firm } + assert_nothing_raised { Company.of_first_firm.to_a } + assert_nothing_raised { Client.of_first_firm.to_a } end def test_inheritance_with_default_scope diff --git a/activerecord/test/models/company.rb b/activerecord/test/models/company.rb index 25f19c0fd4..0e4e0fdaab 100644 --- a/activerecord/test/models/company.rb +++ b/activerecord/test/models/company.rb @@ -9,6 +9,7 @@ class Company < AbstractCompany validates_presence_of :name + has_one :account, foreign_key: "firm_id" has_one :dummy_account, foreign_key: "firm_id", class_name: "Account" has_many :contracts has_many :developers, through: :contracts @@ -16,8 +17,7 @@ class Company < AbstractCompany attribute :metadata, :json scope :of_first_firm, lambda { - joins(account: :firm). - where("firms.id" => 1) + joins(account: :firm).where("companies.id": 1) } def arbitrary_method