From 417f5b8f70eb9f8e7d679462238b21f92edcb861 Mon Sep 17 00:00:00 2001 From: Peter Karman Date: Tue, 11 Oct 2016 12:04:48 -0500 Subject: [PATCH] Add RSA sign_pss() and verify_pss() methods Support Probabilistic Signature Scheme for RSA key signing. [ky: the patch was originally submitted as GitHub Pull Request #76. finish keyword arguments handling, update docs, and fix tests.] --- ext/openssl/ossl_pkey_rsa.c | 192 ++++++++++++++++++++++++++++++++++++ test/test_pkey_rsa.rb | 33 +++++++ 2 files changed, 225 insertions(+) diff --git a/ext/openssl/ossl_pkey_rsa.c b/ext/openssl/ossl_pkey_rsa.c index 26397bd02..4800fb271 100644 --- a/ext/openssl/ossl_pkey_rsa.c +++ b/ext/openssl/ossl_pkey_rsa.c @@ -536,6 +536,196 @@ ossl_rsa_private_decrypt(int argc, VALUE *argv, VALUE self) return str; } +/* + * call-seq: + * rsa.sign_pss(digest, data, salt_length:, mgf1_hash:) -> String + * + * Signs _data_ using the Probabilistic Signature Scheme (RSA-PSS) and returns + * the calculated signature. + * + * RSAError will be raised if an error occurs. + * + * See #verify_pss for the verification operation. + * + * === Parameters + * _digest_:: + * A String containing the message digest algorithm name. + * _data_:: + * A String. The data to be signed. + * _salt_length_:: + * The length in octets of the salt. Two special values are reserved: + * +:digest+ means the digest length, and +:max+ means the maximum possible + * length for the combination of the private key and the selected message + * digest algorithm. + * _mgf1_hash_:: + * The hash algorithm used in MGF1 (the currently supported mask generation + * function (MGF)). + * + * === Example + * data = "Sign me!" + * pkey = OpenSSL::PKey::RSA.new(2048) + * signature = pkey.sign_pss("SHA256", data, salt_length: :max, mgf1_hash: "SHA256") + * pub_key = pkey.public_key + * puts pub_key.verify_pss("SHA256", signature, data, + * salt_length: :auto, mgf1_hash: "SHA256") # => true + */ +static VALUE +ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) +{ + VALUE digest, data, options, kwargs[2], signature; + static ID kwargs_ids[2]; + EVP_PKEY *pkey; + EVP_PKEY_CTX *pkey_ctx; + const EVP_MD *md, *mgf1md; + EVP_MD_CTX *md_ctx; + size_t buf_len; + int salt_len; + + if (!kwargs_ids[0]) { + kwargs_ids[0] = rb_intern_const("salt_length"); + kwargs_ids[1] = rb_intern_const("mgf1_hash"); + } + rb_scan_args(argc, argv, "2:", &digest, &data, &options); + rb_get_kwargs(options, kwargs_ids, 2, 0, kwargs); + if (kwargs[0] == ID2SYM(rb_intern("max"))) + salt_len = -2; /* RSA_PSS_SALTLEN_MAX_SIGN */ + else if (kwargs[0] == ID2SYM(rb_intern("digest"))) + salt_len = -1; /* RSA_PSS_SALTLEN_DIGEST */ + else + salt_len = NUM2INT(kwargs[0]); + mgf1md = ossl_evp_get_digestbyname(kwargs[1]); + + pkey = GetPrivPKeyPtr(self); + buf_len = EVP_PKEY_size(pkey); + md = ossl_evp_get_digestbyname(digest); + StringValue(data); + signature = rb_str_new(NULL, (long)buf_len); + + md_ctx = EVP_MD_CTX_new(); + if (!md_ctx) + goto err; + + if (EVP_DigestSignInit(md_ctx, &pkey_ctx, md, NULL, pkey) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, salt_len) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1md) != 1) + goto err; + + if (EVP_DigestSignUpdate(md_ctx, RSTRING_PTR(data), RSTRING_LEN(data)) != 1) + goto err; + + if (EVP_DigestSignFinal(md_ctx, (unsigned char *)RSTRING_PTR(signature), &buf_len) != 1) + goto err; + + rb_str_set_len(signature, (long)buf_len); + + EVP_MD_CTX_free(md_ctx); + return signature; + + err: + EVP_MD_CTX_free(md_ctx); + ossl_raise(eRSAError, NULL); +} + +/* + * call-seq: + * rsa.verify_pss(digest, signature, data, salt_length:, mgf1_hash:) -> true | false + * + * Verifies _data_ using the Probabilistic Signature Scheme (RSA-PSS). + * + * The return value is +true+ if the signature is valid, +false+ otherwise. + * RSAError will be raised if an error occurs. + * + * See #sign_pss for the signing operation and an example code. + * + * === Parameters + * _digest_:: + * A String containing the message digest algorithm name. + * _data_:: + * A String. The data to be signed. + * _salt_length_:: + * The length in octets of the salt. Two special values are reserved: + * +:digest+ means the digest length, and +:auto+ means automatically + * determining the length based on the signature. + * _mgf1_hash_:: + * The hash algorithm used in MGF1. + */ +static VALUE +ossl_rsa_verify_pss(int argc, VALUE *argv, VALUE self) +{ + VALUE digest, signature, data, options, kwargs[2]; + static ID kwargs_ids[2]; + EVP_PKEY *pkey; + EVP_PKEY_CTX *pkey_ctx; + const EVP_MD *md, *mgf1md; + EVP_MD_CTX *md_ctx; + int result, salt_len; + + if (!kwargs_ids[0]) { + kwargs_ids[0] = rb_intern_const("salt_length"); + kwargs_ids[1] = rb_intern_const("mgf1_hash"); + } + rb_scan_args(argc, argv, "3:", &digest, &signature, &data, &options); + rb_get_kwargs(options, kwargs_ids, 2, 0, kwargs); + if (kwargs[0] == ID2SYM(rb_intern("auto"))) + salt_len = -2; /* RSA_PSS_SALTLEN_AUTO */ + else if (kwargs[0] == ID2SYM(rb_intern("digest"))) + salt_len = -1; /* RSA_PSS_SALTLEN_DIGEST */ + else + salt_len = NUM2INT(kwargs[0]); + mgf1md = ossl_evp_get_digestbyname(kwargs[1]); + + GetPKey(self, pkey); + md = ossl_evp_get_digestbyname(digest); + StringValue(signature); + StringValue(data); + + md_ctx = EVP_MD_CTX_new(); + if (!md_ctx) + goto err; + + if (EVP_DigestVerifyInit(md_ctx, &pkey_ctx, md, NULL, pkey) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, salt_len) != 1) + goto err; + + if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1md) != 1) + goto err; + + if (EVP_DigestVerifyUpdate(md_ctx, RSTRING_PTR(data), RSTRING_LEN(data)) != 1) + goto err; + + result = EVP_DigestVerifyFinal(md_ctx, + (unsigned char *)RSTRING_PTR(signature), + RSTRING_LEN(signature)); + + switch (result) { + case 0: + ossl_clear_error(); + EVP_MD_CTX_free(md_ctx); + return Qfalse; + case 1: + EVP_MD_CTX_free(md_ctx); + return Qtrue; + default: + goto err; + } + + err: + EVP_MD_CTX_free(md_ctx); + ossl_raise(eRSAError, NULL); +} + /* * call-seq: * rsa.params => hash @@ -731,6 +921,8 @@ Init_ossl_rsa(void) rb_define_method(cRSA, "public_decrypt", ossl_rsa_public_decrypt, -1); rb_define_method(cRSA, "private_encrypt", ossl_rsa_private_encrypt, -1); rb_define_method(cRSA, "private_decrypt", ossl_rsa_private_decrypt, -1); + rb_define_method(cRSA, "sign_pss", ossl_rsa_sign_pss, -1); + rb_define_method(cRSA, "verify_pss", ossl_rsa_verify_pss, -1); DEF_OSSL_PKEY_BN(cRSA, rsa, n); DEF_OSSL_PKEY_BN(cRSA, rsa, e); diff --git a/test/test_pkey_rsa.rb b/test/test_pkey_rsa.rb index 49ab37925..d9bea1a62 100644 --- a/test/test_pkey_rsa.rb +++ b/test/test_pkey_rsa.rb @@ -113,6 +113,39 @@ def test_verify_empty_rsa } end + def test_sign_verify_pss + key = Fixtures.pkey("rsa1024") + data = "Sign me!" + invalid_data = "Sign me?" + + signature = key.sign_pss("SHA256", data, salt_length: 20, mgf1_hash: "SHA1") + assert_equal 128, signature.bytesize + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: 20, mgf1_hash: "SHA1") + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: :auto, mgf1_hash: "SHA1") + assert_equal false, + key.verify_pss("SHA256", signature, invalid_data, salt_length: 20, mgf1_hash: "SHA1") + + signature = key.sign_pss("SHA256", data, salt_length: :digest, mgf1_hash: "SHA1") + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: 32, mgf1_hash: "SHA1") + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: :auto, mgf1_hash: "SHA1") + assert_equal false, + key.verify_pss("SHA256", signature, data, salt_length: 20, mgf1_hash: "SHA1") + + signature = key.sign_pss("SHA256", data, salt_length: :max, mgf1_hash: "SHA1") + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: 94, mgf1_hash: "SHA1") + assert_equal true, + key.verify_pss("SHA256", signature, data, salt_length: :auto, mgf1_hash: "SHA1") + + assert_raise(OpenSSL::PKey::RSAError) { + key.sign_pss("SHA256", data, salt_length: 95, mgf1_hash: "SHA1") + } + end + def test_RSAPrivateKey rsa1024 = Fixtures.pkey("rsa1024") asn1 = OpenSSL::ASN1::Sequence([