Mirror of https://github.com/rails/rails.git (synced 2022-11-09 12:12:34 -05:00)
Modularize postgresql adapter
This commit is contained in:
parent: 5f99bdbec2
commit: 232d2223eb
7 changed files with 1040 additions and 990 deletions
@@ -0,0 +1,80 @@
module ActiveRecord
  module ConnectionAdapters
    class PostgreSQLColumn < Column
      module Cast
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'; 1.0 / 0.0
          when '-infinity'; -1.0 / 0.0
          else
            super
          end
        end

        def hstore_to_string(object)
          if Hash === object
            object.map { |k,v|
              "#{escape_hstore(k)}=>#{escape_hstore(v)}"
            }.join ','
          else
            object
          end
        end

        def string_to_hstore(string)
          if string.nil?
            nil
          elsif String === string
            Hash[string.scan(HstorePair).map { |k,v|
              v = v.upcase == 'NULL' ? nil : v.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              k = k.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              [k,v]
            }]
          else
            string
          end
        end

        def string_to_cidr(string)
          if string.nil?
            nil
          elsif String === string
            IPAddr.new(string)
          else
            string
          end
        end

        def cidr_to_string(object)
          if IPAddr === object
            "#{object.to_s}/#{object.instance_variable_get(:@mask_addr).to_s(2).count('1')}"
          else
            object
          end
        end

        private

        HstorePair = begin
          quoted_string = /"[^"\\]*(?:\\.[^"\\]*)*"/
          unquoted_string = /(?:\\.|[^\s,])[^\s=,\\]*(?:\\.[^\s=,\\]*|=[^,>])*/
          /(#{quoted_string}|#{unquoted_string})\s*=>\s*(#{quoted_string}|#{unquoted_string})/
        end

        def escape_hstore(value)
          if value.nil?
            'NULL'
          else
            if value == ""
              '""'
            else
              '"%s"' % value.to_s.gsub(/(["\\])/, '\\\\\1')
            end
          end
        end
      end
    end
  end
end
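For context only (not part of this commit): a minimal sketch of how the Cast helpers round-trip hstore and cidr values. It assumes PostgreSQLColumn exposes these helpers at the class level (the Quoting module further down calls PostgreSQLColumn.hstore_to_string that way, so the adapter is presumed to extend Cast in the suppressed postgresql_adapter.rb diff).

# Illustrative sketch; assumes PostgreSQLColumn extends Cast (class-level access).
column = ActiveRecord::ConnectionAdapters::PostgreSQLColumn
column.hstore_to_string('a' => '1', 'b' => nil)      # => "\"a\"=>\"1\",\"b\"=>NULL"
column.string_to_hstore('"a"=>"1", "b"=>NULL')       # => {"a"=>"1", "b"=>nil}
column.cidr_to_string(IPAddr.new('192.168.2.0/24'))  # => "192.168.2.0/24"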
@@ -0,0 +1,234 @@
module ActiveRecord
  module ConnectionAdapters
    class PostgreSQLAdapter < AbstractAdapter
      module DatabaseStatements
        def explain(arel, binds = [])
          sql = "EXPLAIN #{to_sql(arel, binds)}"
          ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
        end

        class ExplainPrettyPrinter # :nodoc:
          # Pretty prints the result of an EXPLAIN in a way that resembles the output of the
          # PostgreSQL shell:
          #
          #                                     QUERY PLAN
          #   ------------------------------------------------------------------------------
          #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
          #      Join Filter: (posts.user_id = users.id)
          #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
          #            Index Cond: (id = 1)
          #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
          #            Filter: (posts.user_id = 1)
          #   (6 rows)
          #
          def pp(result)
            header = result.columns.first
            lines  = result.rows.map(&:first)

            # We add 2 because there's one char of padding at both sides, note
            # the extra hyphens in the example above.
            width = [header, *lines].map(&:length).max + 2

            pp = []

            pp << header.center(width).rstrip
            pp << '-' * width

            pp += lines.map {|line| " #{line}"}

            nrows = result.rows.length
            rows_label = nrows == 1 ? 'row' : 'rows'
            pp << "(#{nrows} #{rows_label})"

            pp.join("\n") + "\n"
          end
        end

        # Executes a SELECT query and returns an array of rows. Each row is an
        # array of field values.
        def select_rows(sql, name = nil)
          select_raw(sql, name).last
        end

        # Executes an INSERT query and returns the new record's ID
        def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
          unless pk
            # Extract the table from the insert sql. Yuck.
            table_ref = extract_table_ref_from_insert_sql(sql)
            pk = primary_key(table_ref) if table_ref
          end

          if pk && use_insert_returning?
            select_value("#{sql} RETURNING #{quote_column_name(pk)}")
          elsif pk
            super
            last_insert_id_value(sequence_name || default_sequence_name(table_ref, pk))
          else
            super
          end
        end

        def create
          super.insert
        end

        # create a 2D array representing the result set
        def result_as_array(res) #:nodoc:
          # check if we have any binary column and if they need escaping
          ftypes = Array.new(res.nfields) do |i|
            [i, res.ftype(i)]
          end

          rows = res.values
          return rows unless ftypes.any? { |_, x|
            x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
          }

          typehash = ftypes.group_by { |_, type| type }
          binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
          monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

          rows.each do |row|
            # unescape string passed BYTEA field (OID == 17)
            binaries.each do |index, _|
              row[index] = unescape_bytea(row[index])
            end

            # If this is a money type column and there are any currency symbols,
            # then strip them off. Indeed it would be prettier to do this in
            # PostgreSQLColumn.string_to_decimal but would break form input
            # fields that call value_before_type_cast.
            monies.each do |index, _|
              data = row[index]
              # Because money output is formatted according to the locale, there are two
              # cases to consider (note the decimal separators):
              #  (1) $12,345,678.12
              #  (2) $12.345.678,12
              case data
              when /^-?\D+[\d,]+\.\d{2}$/  # (1)
                data.gsub!(/[^-\d.]/, '')
              when /^-?\D+[\d.]+,\d{2}$/  # (2)
                data.gsub!(/[^-\d,]/, '').sub!(/,/, '.')
              end
            end
          end
        end

        # Queries the database and returns the results in an Array-like object
        def query(sql, name = nil) #:nodoc:
          log(sql, name) do
            result_as_array @connection.async_exec(sql)
          end
        end

        # Executes an SQL statement, returning a PGresult object on success
        # or raising a PGError exception otherwise.
        def execute(sql, name = nil)
          log(sql, name) do
            @connection.async_exec(sql)
          end
        end

        def substitute_at(column, index)
          Arel::Nodes::BindParam.new "$#{index + 1}"
        end

        def exec_query(sql, name = 'SQL', binds = [])
          log(sql, name, binds) do
            result = binds.empty? ? exec_no_cache(sql, binds) :
                                    exec_cache(sql, binds)

            types = {}
            result.fields.each_with_index do |fname, i|
              ftype = result.ftype i
              fmod  = result.fmod i
              types[fname] = OID::TYPE_MAP.fetch(ftype, fmod) { |oid, mod|
                warn "unknown OID: #{fname}(#{oid}) (#{sql})"
                OID::Identity.new
              }
            end

            ret = ActiveRecord::Result.new(result.fields, result.values, types)
            result.clear
            return ret
          end
        end

        def exec_delete(sql, name = 'SQL', binds = [])
          log(sql, name, binds) do
            result = binds.empty? ? exec_no_cache(sql, binds) :
                                    exec_cache(sql, binds)
            affected = result.cmd_tuples
            result.clear
            affected
          end
        end
        alias :exec_update :exec_delete

        def sql_for_insert(sql, pk, id_value, sequence_name, binds)
          unless pk
            # Extract the table from the insert sql. Yuck.
            table_ref = extract_table_ref_from_insert_sql(sql)
            pk = primary_key(table_ref) if table_ref
          end

          if pk && use_insert_returning?
            sql = "#{sql} RETURNING #{quote_column_name(pk)}"
          end

          [sql, binds]
        end

        def exec_insert(sql, name, binds, pk = nil, sequence_name = nil)
          val = exec_query(sql, name, binds)
          if !use_insert_returning? && pk
            unless sequence_name
              table_ref = extract_table_ref_from_insert_sql(sql)
              sequence_name = default_sequence_name(table_ref, pk)
              return val unless sequence_name
            end
            last_insert_id_result(sequence_name)
          else
            val
          end
        end

        # Executes an UPDATE query and returns the number of affected tuples.
        def update_sql(sql, name = nil)
          super.cmd_tuples
        end

        # Begins a transaction.
        def begin_db_transaction
          execute "BEGIN"
        end

        # Commits a transaction.
        def commit_db_transaction
          execute "COMMIT"
        end

        # Aborts a transaction.
        def rollback_db_transaction
          execute "ROLLBACK"
        end

        def outside_transaction?
          @connection.transaction_status == PGconn::PQTRANS_IDLE
        end

        def create_savepoint
          execute("SAVEPOINT #{current_savepoint_name}")
        end

        def rollback_to_savepoint
          execute("ROLLBACK TO SAVEPOINT #{current_savepoint_name}")
        end

        def release_savepoint
          execute("RELEASE SAVEPOINT #{current_savepoint_name}")
        end
      end
    end
  end
end
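For illustration only (not part of the diff): ExplainPrettyPrinter#pp only needs an object responding to #columns and #rows, so a hypothetical stand-in result is enough to see the psql-style rendering described in the comment above.

# FakeResult is a hypothetical stand-in; pp reads only #columns and #rows.
FakeResult = Struct.new(:columns, :rows)
plan = FakeResult.new(['QUERY PLAN'],
                      [['Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)']])
printer = ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::DatabaseStatements::ExplainPrettyPrinter.new
puts printer.pp(plan)
# Prints the plan line under a centered QUERY PLAN header, a hyphen rule sized
# to the widest line plus padding, and a "(1 row)" footer.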
@@ -0,0 +1,120 @@
module ActiveRecord
  module ConnectionAdapters
    class PostgreSQLAdapter < AbstractAdapter
      module Quoting
        # Escapes binary strings for bytea input to the database.
        def escape_bytea(value)
          PGconn.escape_bytea(value) if value
        end

        # Unescapes bytea output from a database to the binary string it represents.
        # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
        # on escaped binary output from the database driver.
        def unescape_bytea(value)
          PGconn.unescape_bytea(value) if value
        end

        # Quotes PostgreSQL-specific data types for SQL input.
        def quote(value, column = nil) #:nodoc:
          return super unless column

          case value
          when Hash
            case column.sql_type
            when 'hstore' then super(PostgreSQLColumn.hstore_to_string(value), column)
            else super
            end
          when IPAddr
            case column.sql_type
            when 'inet', 'cidr' then super(PostgreSQLColumn.cidr_to_string(value), column)
            else super
            end
          when Float
            if value.infinite? && column.type == :datetime
              "'#{value.to_s.downcase}'"
            elsif value.infinite? || value.nan?
              "'#{value.to_s}'"
            else
              super
            end
          when Numeric
            return super unless column.sql_type == 'money'
            # Not truly string input, so doesn't require (or allow) escape string syntax.
            "'#{value}'"
          when String
            case column.sql_type
            when 'bytea' then "'#{escape_bytea(value)}'"
            when 'xml'   then "xml '#{quote_string(value)}'"
            when /^bit/
              case value
              when /^[01]*$/      then "B'#{value}'" # Bit-string notation
              when /^[0-9A-F]*$/i then "X'#{value}'" # Hexadecimal notation
              end
            else
              super
            end
          else
            super
          end
        end

        def type_cast(value, column)
          return super unless column

          case value
          when String
            return super unless 'bytea' == column.sql_type
            { :value => value, :format => 1 }
          when Hash
            return super unless 'hstore' == column.sql_type
            PostgreSQLColumn.hstore_to_string(value)
          when IPAddr
            return super unless ['inet','cidr'].include? column.sql_type
            PostgreSQLColumn.cidr_to_string(value)
          else
            super
          end
        end

        # Quotes strings for use in SQL input.
        def quote_string(s) #:nodoc:
          @connection.escape(s)
        end

        # Checks the following cases:
        #
        # - table_name
        # - "table.name"
        # - schema_name.table_name
        # - schema_name."table.name"
        # - "schema.name".table_name
        # - "schema.name"."table.name"
        def quote_table_name(name)
          schema, name_part = extract_pg_identifier_from_name(name.to_s)

          unless name_part
            quote_column_name(schema)
          else
            table_name, name_part = extract_pg_identifier_from_name(name_part)
            "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
          end
        end

        # Quotes column names for use in SQL queries.
        def quote_column_name(name) #:nodoc:
          PGconn.quote_ident(name.to_s)
        end

        # Quote date/time values for use in SQL input. Includes microseconds
        # if the value is a Time responding to usec.
        def quoted_date(value) #:nodoc:
          if value.acts_like?(:time) && value.respond_to?(:usec)
            "#{super}.#{sprintf("%06d", value.usec)}"
          else
            super
          end
        end
      end
    end
  end
end
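A sketch of the quoting behaviour above, assuming a live PostgreSQL connection and a hypothetical pages table with an hstore column tags and an inet column ip (none of these names come from the commit):

conn = ActiveRecord::Base.connection
tags = conn.columns('pages').find { |c| c.name == 'tags' }  # hstore column (hypothetical)
ip   = conn.columns('pages').find { |c| c.name == 'ip' }    # inet column (hypothetical)

conn.quote({ 'a' => '1' }, tags)          # => "'\"a\"=>\"1\"'"
conn.quote(IPAddr.new('10.1.2.0/24'), ip) # => "'10.1.2.0/24'"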
@@ -0,0 +1,22 @@
module ActiveRecord
  module ConnectionAdapters
    class PostgreSQLAdapter < AbstractAdapter
      module ReferentialIntegrity
        def supports_disable_referential_integrity? #:nodoc:
          true
        end

        def disable_referential_integrity #:nodoc:
          if supports_disable_referential_integrity? then
            execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
          end
          yield
        ensure
          if supports_disable_referential_integrity? then
            execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
          end
        end
      end
    end
  end
end
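For illustration (not part of this commit): the typical use of this module is wrapping bulk loads so trigger-enforced foreign keys do not reject out-of-order inserts.

# Triggers (and therefore FK checks) are disabled for every table while the
# block runs, and re-enabled afterwards even if the block raises.
ActiveRecord::Base.connection.disable_referential_integrity do
  # insert fixture rows in any order here
end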
@@ -0,0 +1,446 @@
module ActiveRecord
  module ConnectionAdapters
    class PostgreSQLAdapter < AbstractAdapter
      module SchemaStatements
        # Drops the database specified on the +name+ attribute
        # and creates it again using the provided +options+.
        def recreate_database(name, options = {}) #:nodoc:
          drop_database(name)
          create_database(name, options)
        end

        # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
        # <tt>:encoding</tt>, <tt>:collation</tt>, <tt>:ctype</tt>,
        # <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
        # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
        #
        # Example:
        #   create_database config[:database], config
        #   create_database 'foo_development', :encoding => 'unicode'
        def create_database(name, options = {})
          options = options.reverse_merge(:encoding => "utf8")

          option_string = options.symbolize_keys.sum do |key, value|
            case key
            when :owner
              " OWNER = \"#{value}\""
            when :template
              " TEMPLATE = \"#{value}\""
            when :encoding
              " ENCODING = '#{value}'"
            when :collation
              " LC_COLLATE = '#{value}'"
            when :ctype
              " LC_CTYPE = '#{value}'"
            when :tablespace
              " TABLESPACE = \"#{value}\""
            when :connection_limit
              " CONNECTION LIMIT = #{value}"
            else
              ""
            end
          end

          execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
        end

        # Drops a PostgreSQL database.
        #
        # Example:
        #   drop_database 'matt_development'
        def drop_database(name) #:nodoc:
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        end

        # Returns the list of all tables in the schema search path or a specified schema.
        def tables(name = nil)
          query(<<-SQL, 'SCHEMA').map { |row| row[0] }
            SELECT tablename
            FROM pg_tables
            WHERE schemaname = ANY (current_schemas(false))
          SQL
        end

        # Returns true if table exists.
        # If the schema is not specified as part of +name+ then it will only find tables within
        # the current schema search path (regardless of permissions to access tables in other schemas)
        def table_exists?(name)
          schema, table = Utils.extract_schema_and_table(name.to_s)
          return false unless table

          binds = [[nil, table]]
          binds << [nil, schema] if schema

          exec_query(<<-SQL, 'SCHEMA').rows.first[0].to_i > 0
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = '#{table.gsub(/(^"|"$)/,'')}'
            AND n.nspname = #{schema ? "'#{schema}'" : 'ANY (current_schemas(false))'}
          SQL
        end

        # Returns true if schema exists.
        def schema_exists?(name)
          exec_query(<<-SQL, 'SCHEMA').rows.first[0].to_i > 0
            SELECT COUNT(*)
            FROM pg_namespace
            WHERE nspname = '#{name}'
          SQL
        end

        # Returns an array of indexes for the given table.
        def indexes(table_name, name = nil)
          result = query(<<-SQL, 'SCHEMA')
            SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
            FROM pg_class t
            INNER JOIN pg_index d ON t.oid = d.indrelid
            INNER JOIN pg_class i ON d.indexrelid = i.oid
            WHERE i.relkind = 'i'
              AND d.indisprimary = 'f'
              AND t.relname = '#{table_name}'
              AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
            ORDER BY i.relname
          SQL

          result.map do |row|
            index_name = row[0]
            unique = row[1] == 't'
            indkey = row[2].split(" ")
            inddef = row[3]
            oid = row[4]

            columns = Hash[query(<<-SQL, "Columns for index #{row[0]} on #{table_name}")]
            SELECT a.attnum, a.attname
            FROM pg_attribute a
            WHERE a.attrelid = #{oid}
            AND a.attnum IN (#{indkey.join(",")})
            SQL

            column_names = columns.values_at(*indkey).compact

            # add info on sort order for columns (only desc order is explicitly specified, asc is the default)
            desc_order_columns = inddef.scan(/(\w+) DESC/).flatten
            orders = desc_order_columns.any? ? Hash[desc_order_columns.map {|order_column| [order_column, :desc]}] : {}
            where = inddef.scan(/WHERE (.+)$/).flatten[0]

            column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names, [], orders, where)
          end.compact
        end

        # Returns the list of all column definitions for a table.
        def columns(table_name)
          # Limit, precision, and scale are all handled by the superclass.
          column_definitions(table_name).map do |column_name, type, default, notnull, oid, fmod|
            oid = OID::TYPE_MAP.fetch(oid.to_i, fmod.to_i) {
              OID::Identity.new
            }
            PostgreSQLColumn.new(column_name, default, oid, type, notnull == 'f')
          end
        end

        # Returns the current database name.
        def current_database
          query('select current_database()', 'SCHEMA')[0][0]
        end

        # Returns the current schema name.
        def current_schema
          query('SELECT current_schema', 'SCHEMA')[0][0]
        end

        # Returns the current database encoding format.
        def encoding
          query(<<-end_sql, 'SCHEMA')[0][0]
            SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
            WHERE pg_database.datname LIKE '#{current_database}'
          end_sql
        end

        # Returns the current database collation.
        def collation
          query(<<-end_sql, 'SCHEMA')[0][0]
            SELECT pg_database.datcollate FROM pg_database WHERE pg_database.datname LIKE '#{current_database}'
          end_sql
        end

        # Returns the current database ctype.
        def ctype
          query(<<-end_sql, 'SCHEMA')[0][0]
            SELECT pg_database.datctype FROM pg_database WHERE pg_database.datname LIKE '#{current_database}'
          end_sql
        end

        # Returns an array of schema names.
        def schema_names
          query(<<-SQL, 'SCHEMA').flatten
            SELECT nspname
              FROM pg_namespace
             WHERE nspname !~ '^pg_.*'
               AND nspname NOT IN ('information_schema')
             ORDER by nspname;
          SQL
        end

        # Creates a schema for the given schema name.
        def create_schema schema_name
          execute "CREATE SCHEMA #{schema_name}"
        end

        # Drops the schema for the given schema name.
        def drop_schema schema_name
          execute "DROP SCHEMA #{schema_name} CASCADE"
        end

        # Sets the schema search path to a string of comma-separated schema names.
        # Names beginning with $ have to be quoted (e.g. $user => '$user').
        # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
        #
        # This should not be called manually but set in database.yml.
        def schema_search_path=(schema_csv)
          if schema_csv
            execute("SET search_path TO #{schema_csv}", 'SCHEMA')
            @schema_search_path = schema_csv
          end
        end

        # Returns the active schema search path.
        def schema_search_path
          @schema_search_path ||= query('SHOW search_path', 'SCHEMA')[0][0]
        end

        # Returns the current client message level.
        def client_min_messages
          query('SHOW client_min_messages', 'SCHEMA')[0][0]
        end

        # Set the client message level.
        def client_min_messages=(level)
          execute("SET client_min_messages TO '#{level}'", 'SCHEMA')
        end

        # Returns the sequence name for a table's primary key or some other specified key.
        def default_sequence_name(table_name, pk = nil) #:nodoc:
          result = serial_sequence(table_name, pk || 'id')
          return nil unless result
          result.split('.').last
        rescue ActiveRecord::StatementInvalid
          "#{table_name}_#{pk || 'id'}_seq"
        end

        def serial_sequence(table, column)
          result = exec_query(<<-eosql, 'SCHEMA')
            SELECT pg_get_serial_sequence('#{table}', '#{column}')
          eosql
          result.rows.first.first
        end

        # Resets the sequence of a table's primary key to the maximum value.
        def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
          unless pk and sequence
            default_pk, default_sequence = pk_and_sequence_for(table)

            pk ||= default_pk
            sequence ||= default_sequence
          end

          if @logger && pk && !sequence
            @logger.warn "#{table} has primary key #{pk} with no default sequence"
          end

          if pk && sequence
            quoted_sequence = quote_table_name(sequence)

            select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
            end_sql
          end
        end

        # Returns a table's primary key and belonging sequence.
        def pk_and_sequence_for(table) #:nodoc:
          # First try looking for a sequence with a dependency on the
          # given table's primary key.
          result = query(<<-end_sql, 'PK and serial sequence')[0]
            SELECT attr.attname, seq.relname
            FROM pg_class seq,
                 pg_attribute attr,
                 pg_depend dep,
                 pg_namespace name,
                 pg_constraint cons
            WHERE seq.oid = dep.objid
              AND seq.relkind = 'S'
              AND attr.attrelid = dep.refobjid
              AND attr.attnum = dep.refobjsubid
              AND attr.attrelid = cons.conrelid
              AND attr.attnum = cons.conkey[1]
              AND cons.contype = 'p'
              AND dep.refobjid = '#{quote_table_name(table)}'::regclass
          end_sql

          if result.nil? or result.empty?
            # If that fails, try parsing the primary key's default value.
            # Support the 7.x and 8.0 nextval('foo'::text) as well as
            # the 8.1+ nextval('foo'::regclass).
            result = query(<<-end_sql, 'PK and custom sequence')[0]
              SELECT attr.attname,
                CASE
                  WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                    substr(split_part(def.adsrc, '''', 2),
                           strpos(split_part(def.adsrc, '''', 2), '.')+1)
                  ELSE split_part(def.adsrc, '''', 2)
                END
              FROM pg_class t
              JOIN pg_attribute attr ON (t.oid = attrelid)
              JOIN pg_attrdef def ON (adrelid = attrelid AND adnum = attnum)
              JOIN pg_constraint cons ON (conrelid = adrelid AND adnum = conkey[1])
              WHERE t.oid = '#{quote_table_name(table)}'::regclass
                AND cons.contype = 'p'
                AND def.adsrc ~* 'nextval'
            end_sql
          end

          [result.first, result.last]
        rescue
          nil
        end

        # Returns just a table's primary key
        def primary_key(table)
          row = exec_query(<<-end_sql, 'SCHEMA').rows.first
            SELECT DISTINCT(attr.attname)
            FROM pg_attribute attr
            INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
            INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
            WHERE cons.contype = 'p'
              AND dep.refobjid = '#{table}'::regclass
          end_sql

          row && row.first
        end

        # Renames a table.
        # Also renames a table's primary key sequence if the sequence name matches the
        # Active Record default.
        #
        # Example:
        #   rename_table('octopuses', 'octopi')
        def rename_table(name, new_name)
          clear_cache!
          execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
          pk, seq = pk_and_sequence_for(new_name)
          if seq == "#{name}_#{pk}_seq"
            new_seq = "#{new_name}_#{pk}_seq"
            execute "ALTER TABLE #{quote_table_name(seq)} RENAME TO #{quote_table_name(new_seq)}"
          end
        end

        # Adds a new column to the named table.
        # See TableDefinition#column for details of the options you can use.
        def add_column(table_name, column_name, type, options = {})
          clear_cache!
          add_column_sql = "ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
          add_column_options!(add_column_sql, options)

          execute add_column_sql
        end

        # Changes the column of a table.
        def change_column(table_name, column_name, type, options = {})
          clear_cache!
          quoted_table_name = quote_table_name(table_name)

          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"

          change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
          change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
        end

        # Changes the default value of a table column.
        def change_column_default(table_name, column_name, default)
          clear_cache!
          execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
        end

        def change_column_null(table_name, column_name, null, default = nil)
          clear_cache!
          unless null || default.nil?
            execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
          end
          execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
        end

        # Renames a column in a table.
        def rename_column(table_name, column_name, new_column_name)
          clear_cache!
          execute "ALTER TABLE #{quote_table_name(table_name)} RENAME COLUMN #{quote_column_name(column_name)} TO #{quote_column_name(new_column_name)}"
        end

        def remove_index!(table_name, index_name) #:nodoc:
          execute "DROP INDEX #{quote_table_name(index_name)}"
        end

        def rename_index(table_name, old_name, new_name)
          execute "ALTER INDEX #{quote_column_name(old_name)} RENAME TO #{quote_table_name(new_name)}"
        end

        def index_name_length
          63
        end

        # Maps logical Rails types to PostgreSQL-specific data types.
        def type_to_sql(type, limit = nil, precision = nil, scale = nil)
          case type.to_s
          when 'binary'
            # PostgreSQL doesn't support limits on binary (bytea) columns.
            # The hard limit is 1Gb, because of a 32-bit size field, and TOAST.
            case limit
            when nil, 0..0x3fffffff; super(type)
            else raise(ActiveRecordError, "No binary type has byte size #{limit}.")
            end
          when 'integer'
            return 'integer' unless limit

            case limit
            when 1, 2; 'smallint'
            when 3, 4; 'integer'
            when 5..8; 'bigint'
            else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
            end
          when 'datetime'
            return super unless precision

            case precision
            when 0..6; "timestamp(#{precision})"
            else raise(ActiveRecordError, "No timestamp type has precision of #{precision}. The allowed range of precision is from 0 to 6")
            end
          else
            super
          end
        end

        # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
        #
        # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
        # requires that the ORDER BY include the distinct column.
        #
        #   distinct("posts.id", "posts.created_at desc")
        def distinct(columns, orders) #:nodoc:
          return "DISTINCT #{columns}" if orders.empty?

          # Construct a clean list of column names from the ORDER BY clause, removing
          # any ASC/DESC modifiers
          order_columns = orders.collect do |s|
            s = s.to_sql unless s.is_a?(String)
            s.gsub(/\s+(ASC|DESC)\s*(NULLS\s+(FIRST|LAST)\s*)?/i, '')
          end
          order_columns.delete_if { |c| c.blank? }
          order_columns = order_columns.zip((0...order_columns.size).to_a).map { |s,i| "#{s} AS alias_#{i}" }

          "DISTINCT #{columns}, #{order_columns * ', '}"
        end
      end
    end
  end
end
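A brief usage sketch of the schema helpers above, assuming a connected PostgreSQL adapter (the table and schema names here are hypothetical, not taken from the commit):

conn = ActiveRecord::Base.connection
conn.create_database 'foo_development', :encoding => 'unicode'
conn.create_schema 'reporting'
conn.schema_search_path = 'reporting,public'
conn.tables                        # => table names visible on the current search path
conn.primary_key('posts')          # => "id"
conn.indexes('posts').map(&:name)  # => names of non-primary indexes on posts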
File diff suppressed because it is too large
@@ -3,8 +3,6 @@ require 'cases/helper'
class PostgresqlActiveSchemaTest < ActiveRecord::TestCase
  def setup
    ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.class_eval do
      alias_method :real_execute, :execute
      remove_method :execute
      def execute(sql, name = nil) sql end
    end
  end

@@ -12,7 +10,6 @@ class PostgresqlActiveSchemaTest < ActiveRecord::TestCase
  def teardown
    ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.class_eval do
      remove_method :execute
      alias_method :execute, :real_execute
    end
  end
