diff --git a/lib/gitlab/database.rb b/lib/gitlab/database.rb index 43a00d6cedb..cd7b4c043da 100644 --- a/lib/gitlab/database.rb +++ b/lib/gitlab/database.rb @@ -108,20 +108,41 @@ module Gitlab end end - def self.bulk_insert(table, rows) + # Bulk inserts a number of rows into a table, optionally returning their + # IDs. + # + # table - The name of the table to insert the rows into. + # rows - An Array of Hash instances, each mapping the columns to their + # values. + # return_ids - When set to true the return value will be an Array of IDs of + # the inserted rows, this only works on PostgreSQL. + def self.bulk_insert(table, rows, return_ids: false) return if rows.empty? keys = rows.first.keys columns = keys.map { |key| connection.quote_column_name(key) } + return_ids = false if mysql? tuples = rows.map do |row| row.values_at(*keys).map { |value| connection.quote(value) } end - connection.execute <<-EOF + sql = <<-EOF INSERT INTO #{table} (#{columns.join(', ')}) VALUES #{tuples.map { |tuple| "(#{tuple.join(', ')})" }.join(', ')} EOF + + if return_ids + sql << 'RETURNING id' + end + + result = connection.execute(sql) + + if return_ids + result.values.map { |tuple| tuple[0].to_i } + else + [] + end end def self.sanitize_timestamp(timestamp) diff --git a/spec/lib/gitlab/database_spec.rb b/spec/lib/gitlab/database_spec.rb index 7aeb85b8f5a..fcddfad3f9f 100644 --- a/spec/lib/gitlab/database_spec.rb +++ b/spec/lib/gitlab/database_spec.rb @@ -202,6 +202,26 @@ describe Gitlab::Database do it 'handles non-UTF-8 data' do expect { described_class.bulk_insert('test', [{ a: "\255" }]) }.not_to raise_error end + + context 'when using PostgreSQL' do + before do + allow(described_class).to receive(:mysql?).and_return(false) + end + + it 'allows the returning of the IDs of the inserted rows' do + result = double(:result, values: [['10']]) + + expect(connection) + .to receive(:execute) + .with(/RETURNING id/) + .and_return(result) + + ids = described_class + .bulk_insert('test', [{ number: 10 }], return_ids: true) + + expect(ids).to eq([10]) + end + end end describe '.create_connection_pool' do