diff --git a/ext/digest/extconf.rb b/ext/digest/extconf.rb index b5abbfa..bb3f7f2 100644 --- a/ext/digest/extconf.rb +++ b/ext/digest/extconf.rb @@ -1,3 +1,7 @@ require 'mkmf' + +have_header('ruby/digest.h') +have_func('rb_str_set_len') + $CFLAGS << " -fvisibility=hidden" create_makefile('digest/sha3') diff --git a/ext/digest/sha3.c b/ext/digest/sha3.c index 5484d74..6758d04 100644 --- a/ext/digest/sha3.c +++ b/ext/digest/sha3.c @@ -1,137 +1,130 @@ #include "ruby.h" +#ifdef HAVE_RUBY_DIGEST_H +#include "ruby/digest.h" +#else +#include "digest.h" +#endif #include "KeccakNISTInterface.h" #define MAX_DIGEST_SIZE 64 +#define DEFAULT_DIGEST_LEN 512 -static VALUE mDigest, cSHA3; +static void sha3_init_func(hashState *ctx); +static void sha3_update_func(hashState *ctx, unsigned char *str, size_t len); -typedef struct { - hashState state; - int bitlen; -} RbSHA3; +static rb_digest_metadata_t sha3 = { + RUBY_DIGEST_API_VERSION, + DEFAULT_DIGEST_LEN, + KeccakPermutationSize - (2 * DEFAULT_DIGEST_LEN), + sizeof(hashState), + (rb_digest_hash_init_func_t)sha3_init_func, + (rb_digest_hash_update_func_t)sha3_update_func, + NULL, +}; + +static void +sha3_init(hashState *ctx, size_t bitlen) { + switch (Init(ctx, bitlen)) { + case SUCCESS: + return; + case FAIL: + rb_raise(rb_eRuntimeError, "Unknown error"); + case BAD_HASHLEN: + rb_raise(rb_eRuntimeError, "Bad hash length (must be 0, 224, 256, 384 or 512)"); + default: + rb_raise(rb_eRuntimeError, "Unknown error code"); + } +} + +static void +sha3_init_func(hashState *ctx) { + Init(ctx, ctx->capacity / 2); +} + +static void +sha3_update_func(hashState *ctx, unsigned char *str, size_t len) { + Update(ctx, str, len * 8); +} static VALUE rb_sha3_alloc(VALUE klass) { - RbSHA3 *ctx; - - ctx = (RbSHA3 *) xmalloc(sizeof(RbSHA3)); - ctx->bitlen = -1; + hashState *ctx; + + ctx = (hashState *) xmalloc(sizeof(hashState)); + sha3_init(ctx, DEFAULT_DIGEST_LEN); return Data_Wrap_Struct(klass, 0, xfree, ctx); } static VALUE rb_sha3_initialize(int argc, VALUE *argv, VALUE self) { - RbSHA3 *ctx; + hashState *ctx; VALUE hashlen; int i_hashlen; if (rb_scan_args(argc, argv, "01", &hashlen) == 0) { - i_hashlen = 512; + i_hashlen = DEFAULT_DIGEST_LEN; } else { i_hashlen = NUM2INT(hashlen); } - if (i_hashlen == 0) { + switch (i_hashlen) { + case 0: rb_raise(rb_eRuntimeError, "Unsupported hash length"); - } - - Data_Get_Struct(self, RbSHA3, ctx); - ctx->bitlen = i_hashlen; - - switch (Init(&ctx->state, i_hashlen)) { - case SUCCESS: - return self; - case FAIL: - rb_raise(rb_eRuntimeError, "Unknown error"); - return Qnil; - case BAD_HASHLEN: - rb_raise(rb_eRuntimeError, "Bad hash length (must be 0, 224, 256, 384 or 512)"); - return Qnil; + case DEFAULT_DIGEST_LEN: + break; default: - rb_raise(rb_eRuntimeError, "Unknown error code"); - return Qnil; + Data_Get_Struct(self, hashState, ctx); + sha3_init(ctx, i_hashlen); } -} -static VALUE -rb_sha3_initialize_copy(VALUE self, VALUE other) { - RbSHA3 *ctx_self, *ctx_other; - - rb_check_frozen(self); - Data_Get_Struct(self, RbSHA3, ctx_self); - Data_Get_Struct(other, RbSHA3, ctx_other); - memcpy(&ctx_self->state, &ctx_other->state, sizeof(hashState)); - ctx_self->bitlen = ctx_other->bitlen; return self; } static VALUE -rb_sha3_reset(VALUE self) { - RbSHA3 *ctx; +rb_sha3_finish(VALUE self) { + hashState *ctx; + VALUE digest; - Data_Get_Struct(self, RbSHA3, ctx); - Init(&ctx->state, ctx->bitlen); - return self; + Data_Get_Struct(self, hashState, ctx); + + digest = rb_str_new(0, ctx->capacity / 2 / 8); + + Final(ctx, (unsigned char *)RSTRING_PTR(digest)); + + return digest; } static VALUE -rb_sha3_update(VALUE self, VALUE str) { - RbSHA3 *ctx; +rb_sha3_digest_length(VALUE self) { + hashState *ctx; - Data_Get_Struct(self, RbSHA3, ctx); - Update(&ctx->state, RSTRING_PTR(str), RSTRING_LEN(str) * 8); - return self; + Data_Get_Struct(self, hashState, ctx); + return INT2FIX(ctx->capacity / 2 / 8); } static VALUE -rb_sha3_digest(VALUE self, VALUE str) { - RbSHA3 *ctx; - hashState state; - unsigned char digest[MAX_DIGEST_SIZE]; +rb_sha3_block_length(VALUE self) { + hashState *ctx; - Data_Get_Struct(self, RbSHA3, ctx); - memcpy(&state, &ctx->state, sizeof(hashState)); - Final(&state, digest); - return rb_str_new((const char *) digest, ctx->bitlen / 8); -} - -static VALUE -rb_sha3_singleton_digest(int argc, VALUE *argv, VALUE klass) { - VALUE data, hashlen; - int i_hashlen; - unsigned char digest[MAX_DIGEST_SIZE]; - - if (rb_scan_args(argc, argv, "11", &data, &hashlen) == 1) { - i_hashlen = 512; - } else { - i_hashlen = NUM2INT(hashlen); - } - - switch (Hash(i_hashlen, RSTRING_PTR(data), RSTRING_LEN(data) * 8, digest)) { - case SUCCESS: - return rb_str_new(digest, i_hashlen / 8); - case FAIL: - rb_raise(rb_eRuntimeError, "Unknown error"); - return Qnil; - case BAD_HASHLEN: - rb_raise(rb_eRuntimeError, "Bad hash length (must be 0, 224, 256, 384 or 512)"); - return Qnil; - default: - rb_raise(rb_eRuntimeError, "Unknown error code"); - return Qnil; - } + Data_Get_Struct(self, hashState, ctx); + return INT2FIX(ctx->rate / 8); } void __attribute__((visibility("default"))) Init_sha3() { - mDigest = rb_define_module("Digest"); - cSHA3 = rb_define_class_under(mDigest, "SHA3", rb_cObject); + VALUE mDigest, cDigest_Base, cSHA3; + + rb_require("digest"); + + mDigest = rb_path2class("Digest"); + cDigest_Base = rb_path2class("Digest::Base"); + + cSHA3 = rb_define_class_under(mDigest, "SHA3", cDigest_Base); + + rb_ivar_set(cSHA3, rb_intern("metadata"), Data_Wrap_Struct(rb_cObject, 0, 0, &sha3)); + rb_define_alloc_func(cSHA3, rb_sha3_alloc); rb_define_method(cSHA3, "initialize", rb_sha3_initialize, -1); - rb_define_method(cSHA3, "initialize_copy", rb_sha3_initialize_copy, 1); - rb_define_method(cSHA3, "reset", rb_sha3_reset, 0); - rb_define_method(cSHA3, "update", rb_sha3_update, 1); - rb_define_method(cSHA3, "<<", rb_sha3_update, 1); - rb_define_method(cSHA3, "digest", rb_sha3_digest, 0); - rb_define_singleton_method(cSHA3, "digest", rb_sha3_singleton_digest, -1); - rb_require("digest/sha3/helpers"); + rb_define_private_method(cSHA3, "finish", rb_sha3_finish, 0); + rb_define_method(cSHA3, "digest_length", rb_sha3_digest_length, 0); + rb_define_method(cSHA3, "block_length", rb_sha3_block_length, 0); } diff --git a/lib/digest/sha3/helpers.rb b/lib/digest/sha3/helpers.rb deleted file mode 100644 index d9349ac..0000000 --- a/lib/digest/sha3/helpers.rb +++ /dev/null @@ -1,20 +0,0 @@ -# encoding: ascii -Digest::SHA3.class_eval do - def self.hexdigest(*args) - force_ascii(digest(*args).unpack("H*").first) - end - - def hexdigest - Digest::SHA3.force_ascii(digest.unpack("H*").first) - end - - if ''.respond_to?(:force_encoding) - def self.force_ascii(str) - str.force_encoding('ascii') - end - else - def self.force_ascii(str) - str - end - end -end