diff --git a/activesupport/lib/active_support/message_encryptor.rb b/activesupport/lib/active_support/message_encryptor.rb index 634ffbac57..955a7d100a 100644 --- a/activesupport/lib/active_support/message_encryptor.rb +++ b/activesupport/lib/active_support/message_encryptor.rb @@ -121,6 +121,11 @@ module ActiveSupport class InvalidMessage < StandardError; end OpenSSLCipherError = OpenSSL::Cipher::CipherError + AUTH_TAG_LENGTH = 16 # :nodoc: + AUTH_TAG_LENGTH_IN_BASE64 = ((4 * AUTH_TAG_LENGTH / 3) + 3) & ~3 # :nodoc: + SEPARATOR = "--" # :nodoc: + SEPARATOR_LENGTH = SEPARATOR.length # :nodoc: + # Initialize a new MessageEncryptor. +secret+ must be at least as long as # the cipher key size. For the default 'aes-256-gcm' cipher, this is 256 # bits. If you are using a user-entered secret, you can generate a suitable @@ -177,19 +182,25 @@ module ActiveSupport encrypted_data = cipher.update(Messages::Metadata.wrap(@serializer.dump(value), **metadata_options)) encrypted_data << cipher.final - blob = "#{::Base64.strict_encode64 encrypted_data}--#{::Base64.strict_encode64 iv}" - blob = "#{blob}--#{::Base64.strict_encode64 cipher.auth_tag}" if aead_mode? - blob + encoded_encrypted_data = ::Base64.strict_encode64(encrypted_data) + encoded_iv = ::Base64.strict_encode64(iv) + + if aead_mode? + encoded_auth_tag = ::Base64.strict_encode64(cipher.auth_tag(AUTH_TAG_LENGTH)) + "#{encoded_encrypted_data}#{SEPARATOR}#{encoded_iv}#{SEPARATOR}#{encoded_auth_tag}" + else + "#{encoded_encrypted_data}#{SEPARATOR}#{encoded_iv}" + end end def _decrypt(encrypted_message, purpose) cipher = new_cipher - encrypted_data, iv, auth_tag = encrypted_message.split("--").map { |v| ::Base64.strict_decode64(v) } + encrypted_data, iv, auth_tag = get_encrypted_data_and_iv_and_auth_tag_from(encrypted_message) # Currently the OpenSSL bindings do not raise an error if auth_tag is # truncated, which would allow an attacker to easily forge it. See # https://github.com/ruby/openssl/issues/63 - raise InvalidMessage if aead_mode? && (auth_tag.nil? || auth_tag.bytes.length != 16) + raise InvalidMessage if aead_mode? && (auth_tag.nil? || auth_tag.bytes.length != AUTH_TAG_LENGTH) cipher.decrypt cipher.key = @secret @@ -208,6 +219,37 @@ module ActiveSupport raise InvalidMessage end + def iv_length_in_base64 + @iv_length_in_base64 ||= ((4 * new_cipher.iv_len / 3) + 3) & ~3 + end + + def separator_at?(encrypted_message, index) + encrypted_message[index, SEPARATOR_LENGTH] == SEPARATOR + end + + def auth_tag_and_iv_separators_indexes_for(encrypted_message) + if aead_mode? + auth_tag_separator_index = encrypted_message.length - AUTH_TAG_LENGTH_IN_BASE64 - SEPARATOR_LENGTH + return if auth_tag_separator_index < SEPARATOR_LENGTH || !separator_at?(encrypted_message, auth_tag_separator_index) + end + + iv_separator_index = (auth_tag_separator_index || encrypted_message.length) - iv_length_in_base64 - SEPARATOR_LENGTH + return if iv_separator_index.negative? || !separator_at?(encrypted_message, iv_separator_index) + + [auth_tag_separator_index, iv_separator_index] + end + + def get_encrypted_data_and_iv_and_auth_tag_from(encrypted_message) + auth_tag_separator_index, iv_separator_index = auth_tag_and_iv_separators_indexes_for(encrypted_message) + return if iv_separator_index.nil? || (aead_mode? && auth_tag_separator_index.nil?) + + encrypted_data = encrypted_message[0, iv_separator_index] + iv = encrypted_message[iv_separator_index + SEPARATOR_LENGTH, iv_length_in_base64] + auth_tag = encrypted_message[auth_tag_separator_index + SEPARATOR_LENGTH, AUTH_TAG_LENGTH_IN_BASE64] if aead_mode? + + [encrypted_data, iv, auth_tag].map! { |v| ::Base64.strict_decode64(v) if v.present? } + end + def new_cipher OpenSSL::Cipher.new(@cipher) end diff --git a/activesupport/lib/active_support/message_verifier.rb b/activesupport/lib/active_support/message_verifier.rb index c224bdc277..b7ddb4a13a 100644 --- a/activesupport/lib/active_support/message_verifier.rb +++ b/activesupport/lib/active_support/message_verifier.rb @@ -211,9 +211,13 @@ module ActiveSupport @digest_length_in_hex ||= OpenSSL::Digest.new(@digest).digest_length * 2 end + def separator_at?(signed_message, index) + signed_message[index, SEPARATOR_LENGTH] == SEPARATOR + end + def separator_index_for(signed_message) index = signed_message.length - digest_length_in_hex - SEPARATOR_LENGTH - return if index.negative? || signed_message[index, SEPARATOR_LENGTH] != SEPARATOR + return if index.negative? || !separator_at?(signed_message, index) index end @@ -224,8 +228,8 @@ module ActiveSupport separator_index = separator_index_for(signed_message) return if separator_index.nil? - data = signed_message[0...separator_index] - digest = signed_message[separator_index + SEPARATOR_LENGTH..-1] + data = signed_message[0, separator_index] + digest = signed_message[separator_index + SEPARATOR_LENGTH, digest_length_in_hex] [data, digest] end