gitlab-org--gitlab-foss/lib/gitlab/database/unidirectional_copy_trigger.rb

97 lines
3 KiB
Ruby

# frozen_string_literal: true
module Gitlab
  module Database
    # Manages a PostgreSQL row-level trigger on one table that, BEFORE INSERT OR
    # UPDATE, copies every source column into its paired destination column
    # (`NEW.<to> := NEW.<from>`). The copy is one-way: this trigger never writes
    # destination values back into the source columns.
    #
    # Each trigger is backed by a PL/pgSQL function that shares the trigger's name.
    #
    # @example
    #   trigger = UnidirectionalCopyTrigger.on_table(:my_table, connection: connection)
    #   trigger.create(:old_column, :new_column)
    #   trigger.drop(trigger.name(:old_column, :new_column))
    class UnidirectionalCopyTrigger
      # @param table_name [String, Symbol] table the trigger is attached to
      # @param connection [Object] database connection; must respond to
      #   #execute, #quote_table_name and #quote_column_name
      # @return [UnidirectionalCopyTrigger]
      def self.on_table(table_name, connection:)
        new(table_name, connection)
      end

      # Returns the name `create` gives the trigger (and its function) for this
      # column mapping when no explicit `trigger_name:` is supplied.
      #
      # @param from_column_names [String, Symbol, Array] source column(s)
      # @param to_column_names [String, Symbol, Array] destination column(s),
      #   paired positionally with from_column_names
      # @return [String] "trigger_" followed by 12 hex characters
      # @raise [ArgumentError] if the mapping is empty or the lists differ in length
      def name(from_column_names, to_column_names)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)

        unchecked_name(from_column_names, to_column_names)
      end

      # Creates (or replaces) the copy function and (re)attaches the trigger to the
      # table. Calling it again for the same name replaces the previous definition:
      # the function uses CREATE OR REPLACE and any existing trigger is dropped first.
      #
      # @param from_column_names [String, Symbol, Array] source column(s)
      # @param to_column_names [String, Symbol, Array] destination column(s)
      # @param trigger_name [String, nil] name for both trigger and function;
      #   defaults to #name for this mapping. It is interpolated into the SQL
      #   verbatim (not quoted), so it must be a trusted, valid unquoted identifier.
      # @raise [ArgumentError] if the mapping is empty or the lists differ in length
      def create(from_column_names, to_column_names, trigger_name: nil)
        from_column_names, to_column_names = check_column_names!(from_column_names, to_column_names)
        trigger_name ||= unchecked_name(from_column_names, to_column_names)

        assignment_clauses = assignment_clauses_for_columns(from_column_names, to_column_names)

        connection.execute(<<~SQL)
          CREATE OR REPLACE FUNCTION #{trigger_name}()
          RETURNS trigger AS
          $BODY$
          BEGIN
            #{assignment_clauses};
            RETURN NEW;
          END;
          $BODY$
          LANGUAGE 'plpgsql'
          VOLATILE
        SQL

        connection.execute(<<~SQL)
          DROP TRIGGER IF EXISTS #{trigger_name}
          ON #{quoted_table_name}
        SQL

        connection.execute(<<~SQL)
          CREATE TRIGGER #{trigger_name}
          BEFORE INSERT OR UPDATE
          ON #{quoted_table_name}
          FOR EACH ROW
          EXECUTE FUNCTION #{trigger_name}()
        SQL
      end

      # Drops the trigger and its backing function; both statements use IF EXISTS,
      # so dropping an absent trigger is not an error.
      #
      # @param trigger_name [String] name used when the trigger was created
      #   (interpolated verbatim, not quoted)
      def drop(trigger_name)
        connection.execute("DROP TRIGGER IF EXISTS #{trigger_name} ON #{quoted_table_name}")
        connection.execute("DROP FUNCTION IF EXISTS #{trigger_name}()")
      end

      private

      attr_reader :table_name, :connection

      def initialize(table_name, connection)
        @table_name = table_name
        @connection = connection
      end

      # Table name quoted by the connection, memoized per instance.
      def quoted_table_name
        @quoted_table_name ||= connection.quote_table_name(table_name)
      end

      # Normalizes both arguments to arrays and verifies they describe a
      # non-empty, positionally paired column mapping.
      def check_column_names!(from_column_names, to_column_names)
        from_column_names = Array.wrap(from_column_names)
        to_column_names = Array.wrap(to_column_names)

        unless from_column_names.size == to_column_names.size
          raise ArgumentError, 'number of source and destination columns must match'
        end

        # An empty mapping renders a function body consisting of a bare `;`,
        # which PL/pgSQL rejects; fail here instead of inside the database.
        raise ArgumentError, 'at least one source and destination column is required' if from_column_names.empty?

        [from_column_names, to_column_names]
      end

      # Derives a deterministic name from the table and the interleaved column
      # pairs, without validating the arguments.
      #
      # NOTE: names are joined with '_', so distinct mappings whose column names
      # contain underscores can yield the same input (e.g. ["a_b"]/["c"] and
      # ["a"]/["b_c"]). The scheme is kept as-is because changing it would change
      # the names of triggers that already exist.
      def unchecked_name(from_column_names, to_column_names)
        joined_column_names = from_column_names.zip(to_column_names).flatten.join('_')
        digest = Digest::SHA256.hexdigest("#{table_name}_#{joined_column_names}")

        "trigger_#{digest.first(12)}"
      end

      # Builds the PL/pgSQL assignments `NEW.<to> := NEW.<from>` with quoted
      # column names, separated by ";\n ".
      def assignment_clauses_for_columns(from_column_names, to_column_names)
        clauses = to_column_names.zip(from_column_names).map do |destination, source|
          "NEW.#{connection.quote_column_name(destination)} := NEW.#{connection.quote_column_name(source)}"
        end

        clauses.join(";\n ")
      end
    end
  end
end