diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 863d7bb1aa..d2473188cb 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -41,23 +41,10 @@ module ActiveRecord # User.where.not(name: "Jon", role: "admin") # # SELECT * FROM users WHERE name != 'Jon' AND role != 'admin' def not(opts, *rest) - where_value = @scope.send(:build_where, opts, rest).map do |rel| - case rel - when NilClass - raise ArgumentError, 'Invalid argument for .where.not(), got nil.' - when Arel::Nodes::In - Arel::Nodes::NotIn.new(rel.left, rel.right) - when Arel::Nodes::Equality - Arel::Nodes::NotEqual.new(rel.left, rel.right) - when String - Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(rel)) - else - Arel::Nodes::Not.new(rel) - end - end + where_clause = @scope.send(:where_clause_factory).build(opts, rest) @scope.references!(PredicateBuilder.references(opts)) if Hash === opts - @scope.where_values += where_value + @scope.where_clause += where_clause.invert @scope end end diff --git a/activerecord/lib/active_record/relation/where_clause.rb b/activerecord/lib/active_record/relation/where_clause.rb index d1469d0a7a..676b0d47d3 100644 --- a/activerecord/lib/active_record/relation/where_clause.rb +++ b/activerecord/lib/active_record/relation/where_clause.rb @@ -30,6 +30,10 @@ module ActiveRecord binds == other.binds end + def invert + WhereClause.new(inverted_parts, binds) + end + def self.empty new([], []) end @@ -60,6 +64,25 @@ module ActiveRecord conflicts.map! { |node| node.name.to_s } binds.reject { |col, _| conflicts.include?(col.name) } end + + def inverted_parts + parts.map { |node| invert_predicate(node) } + end + + def invert_predicate(node) + case node + when NilClass + raise ArgumentError, 'Invalid argument for .where.not(), got nil.' + when Arel::Nodes::In + Arel::Nodes::NotIn.new(node.left, node.right) + when Arel::Nodes::Equality + Arel::Nodes::NotEqual.new(node.left, node.right) + when String + Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(node)) + else + Arel::Nodes::Not.new(node) + end + end end end end diff --git a/activerecord/test/cases/relation/where_clause_test.rb b/activerecord/test/cases/relation/where_clause_test.rb index b498739e84..569082541b 100644 --- a/activerecord/test/cases/relation/where_clause_test.rb +++ b/activerecord/test/cases/relation/where_clause_test.rb @@ -71,6 +71,32 @@ class ActiveRecord::Relation assert_not WhereClause.new(["anything"], []).empty? end + test "invert cannot handle nil" do + where_clause = WhereClause.new([nil], []) + + assert_raises ArgumentError do + where_clause.invert + end + end + + test "invert replaces each part of the predicate with its inverse" do + random_object = Object.new + original = WhereClause.new([ + table["id"].in([1, 2, 3]), + table["id"].eq(1), + "sql literal", + random_object + ], []) + expected = WhereClause.new([ + table["id"].not_in([1, 2, 3]), + table["id"].not_eq(1), + Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new("sql literal")), + Arel::Nodes::Not.new(random_object) + ], []) + + assert_equal expected, original.invert + end + private def table