diff --git a/lib/paper_trail/reifier.rb b/lib/paper_trail/reifier.rb index 8ad47e66..53e5379d 100644 --- a/lib/paper_trail/reifier.rb +++ b/lib/paper_trail/reifier.rb @@ -1,5 +1,6 @@ require "paper_trail/attribute_serializers/object_attribute" require "paper_trail/reifiers/belongs_to" +require "paper_trail/reifiers/has_and_belongs_to_many" require "paper_trail/reifiers/has_many" require "paper_trail/reifiers/has_many_through" require "paper_trail/reifiers/has_one" @@ -102,19 +103,6 @@ module PaperTrail (model.attribute_names - attrs.keys).each { |k| attrs[k] = nil } end - # Given a HABTM association `assoc` and an `id`, return a version record - # from the point in time identified by `transaction_id` or `version_at`. - # @api private - def load_version_for_habtm(assoc, id, transaction_id, version_at) - assoc.klass.paper_trail.version_class. - where("item_type = ?", assoc.klass.name). - where("item_id = ?", id). - where("created_at >= ? OR transaction_id = ?", version_at, transaction_id). - order("id"). - limit(1). - first - end - # Reify onto `model` an attribute named `k` with value `v` from `version`. # # `ObjectAttribute#deserialize` will return the mapped enum value and in @@ -200,47 +188,13 @@ module PaperTrail end end - # Reify a single HABTM association of `model`. - # @api private - def reify_habtm_association(assoc, model, options, papertrail_enabled, transaction_id) - version_ids = PaperTrail::VersionAssociation. - where("foreign_key_name = ?", assoc.name). - where("version_id = ?", transaction_id). - pluck(:foreign_key_id) - - model.send(assoc.name).proxy_association.target = - version_ids.map do |id| - if papertrail_enabled - version = load_version_for_habtm( - assoc, - id, - transaction_id, - options[:version_at] - ) - if version - next version.reify( - options.merge( - has_many: false, - has_one: false, - belongs_to: false, - has_and_belongs_to_many: false - ) - ) - end - end - assoc.klass.where(assoc.klass.primary_key => id).first - end - end - # Reify all HABTM associations of `model`. # @api private def reify_habtm_associations(transaction_id, model, options = {}) model.class.reflect_on_all_associations(:has_and_belongs_to_many).each do |assoc| - papertrail_enabled = assoc.klass.paper_trail.enabled? - next unless - model.class.paper_trail_save_join_tables.include?(assoc.name) || - papertrail_enabled - reify_habtm_association(assoc, model, options, papertrail_enabled, transaction_id) + pt_enabled = assoc.klass.paper_trail.enabled? + next unless model.class.paper_trail_save_join_tables.include?(assoc.name) || pt_enabled + Reifiers::HasAndBelongsToMany.reify(pt_enabled, assoc, model, options, transaction_id) end end diff --git a/lib/paper_trail/reifiers/has_and_belongs_to_many.rb b/lib/paper_trail/reifiers/has_and_belongs_to_many.rb new file mode 100644 index 00000000..4aa028b0 --- /dev/null +++ b/lib/paper_trail/reifiers/has_and_belongs_to_many.rb @@ -0,0 +1,50 @@ +module PaperTrail + module Reifiers + # Reify a single HABTM association of `model`. + # @api private + module HasAndBelongsToMany + class << self + # @api private + def reify(pt_enabled, assoc, model, options, transaction_id) + version_ids = ::PaperTrail::VersionAssociation. + where("foreign_key_name = ?", assoc.name). + where("version_id = ?", transaction_id). + pluck(:foreign_key_id) + + model.send(assoc.name).proxy_association.target = + version_ids.map do |id| + if pt_enabled + version = load_version(assoc, id, transaction_id, options[:version_at]) + if version + next version.reify( + options.merge( + has_many: false, + has_one: false, + belongs_to: false, + has_and_belongs_to_many: false + ) + ) + end + end + assoc.klass.where(assoc.klass.primary_key => id).first + end + end + + private + + # Given a HABTM association `assoc` and an `id`, return a version record + # from the point in time identified by `transaction_id` or `version_at`. + # @api private + def load_version(assoc, id, transaction_id, version_at) + assoc.klass.paper_trail.version_class. + where("item_type = ?", assoc.klass.name). + where("item_id = ?", id). + where("created_at >= ? OR transaction_id = ?", version_at, transaction_id). + order("id"). + limit(1). + first + end + end + end + end +end