diff --git a/.travis.yml b/.travis.yml index bed83e9..ef3201c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,16 +1,22 @@ -sudo: required -dist: trusty - +dist: xenial language: go go: - - "1.10.x" + - "1.11.x" - "master" -# trusty only has softhsmv1 -before_script: - - sudo add-apt-repository -y ppa:pkg-opendnssec/ppa && sudo apt-get update && sudo apt-get install softhsm2 - - curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh +# Xenial comes with v2.0.0 SoftHSM2, which seems to have issues with ECDSA +# code points +addons: + apt: + sources: + - sourceline: 'ppa:pkg-opendnssec/ppa' + packages: + - softhsm2 + +env: + - GO111MODULE=on + script: - echo directories.tokendir = `pwd`/tokens > softhsm2.conf @@ -18,7 +24,5 @@ script: - cat softhsm2.conf - mkdir tokens - export SOFTHSM2_CONF=`pwd`/softhsm2.conf - - softhsm2-util --init-token --slot 0 --label test --so-pin sopassword --pin password - - cp configs/config.softhsm2 config - - dep ensure - - go test -v -bench . + - softhsm2-util --init-token --slot 0 --label token1 --so-pin sopassword --pin password + - go test -mod readonly -v -bench . diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index 1e78dac..0000000 --- a/Gopkg.lock +++ /dev/null @@ -1,86 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. - - -[[projects]] - digest = "1:0deddd908b6b4b768cfc272c16ee61e7088a60f7fe2f06c547bd3d8e1f8b8e77" - name = "github.com/davecgh/go-spew" - packages = ["spew"] - pruneopts = "" - revision = "8991bc29aa16c548c550c7ff78260e27b9ab7c73" - version = "v1.1.1" - -[[projects]] - branch = "master" - digest = "1:e48e3de0a7d38e6d55730d0ade3624650ccdbfcc8f518db3486401337cac617b" - name = "github.com/miekg/pkcs11" - packages = ["."] - pruneopts = "" - revision = "c6d6ee821fb161c8022ceb8ba93ce3b815d8d62e" - -[[projects]] - digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" - name = "github.com/pmezard/go-difflib" - packages = ["difflib"] - pruneopts = "" - revision = "792786c7400a136282c1664665ae0a8db921c6c2" - version = "v1.0.0" - -[[projects]] - digest = "1:711eebe744c0151a9d09af2315f0bb729b2ec7637ef4c410fa90a18ef74b65b6" - name = "github.com/stretchr/objx" - packages = ["."] - pruneopts = "" - revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" - version = "v0.1.1" - -[[projects]] - digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c" - name = "github.com/stretchr/testify" - packages = [ - ".", - "assert", - "http", - "mock", - ] - pruneopts = "" - revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053" - version = "v1.3.0" - -[[projects]] - digest = "1:c04f8425afbb4fe70847f832fb4f773db79f9b0460bdbdd112282a369a401638" - name = "github.com/youtube/vitess" - packages = ["go/pools"] - pruneopts = "" - revision = "66e84fadcc1a7e956e7ffcebcaaba0b04132ca1f" - version = "v2.2" - -[[projects]] - branch = "master" - digest = "1:08e41d63f8dac84d83797368b56cf0b339e42d0224e5e56668963c28aec95685" - name = "golang.org/x/net" - packages = ["context"] - pruneopts = "" - revision = "4dfa2610cdf3b287375bbba5b8f2a14d3b01d8de" - -[[projects]] - digest = "1:c04f8425afbb4fe70847f832fb4f773db79f9b0460bdbdd112282a369a401638" - name = "vitess.io/vitess" - packages = [ - "go/cache", - "go/sync2", - "go/timer", - ] - pruneopts = "" - revision = "66e84fadcc1a7e956e7ffcebcaaba0b04132ca1f" - version = "v2.2" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - input-imports = [ - "github.com/miekg/pkcs11", - "github.com/stretchr/testify", - "github.com/youtube/vitess/go/pools", - ] - solver-name = "gps-cdcl" - solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index d604a73..0000000 --- a/Gopkg.toml +++ /dev/null @@ -1,30 +0,0 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - branch = "master" - name = "github.com/miekg/pkcs11" - -[[constraint]] - name = "github.com/stretchr/testify" - version = "1.3.0" diff --git a/README.md b/README.md index 6add6df..2923af6 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Crypto11 [![GoDoc](https://godoc.org/github.com/ThalesIgnite/crypto11?status.svg)](https://godoc.org/github.com/ThalesIgnite/crypto11) [![Build Status](https://travis-ci.com/ThalesIgnite/crypto11.svg?branch=master)](https://travis-ci.com/ThalesIgnite/crypto11) -This is an implementation of the standard Golang hardware crypto interface that +This is an implementation of the standard Golang crypto interfaces that uses [PKCS#11](http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/errata01/os/pkcs11-base-v2.40-errata01-os-complete.html) as a backend. The supported features are: * Generation and retrieval of RSA, DSA and ECDSA keys. @@ -15,8 +15,8 @@ uses [PKCS#11](http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/errata01/os/p * ECDSA signing. * DSA signing. * Random number generation. -* (Experimental) AES and DES3 encryption and decryption. -* (Experimental) HMAC support. +* AES and DES3 encryption and decryption. +* HMAC support. Signing is done through the [crypto.Signer](https://golang.org/pkg/crypto/#Signer) interface and @@ -25,37 +25,72 @@ decryption through To verify signatures or encrypt messages, retrieve the public key and do it in software. -See the documentation for details of various limitations. +See [the documentation](https://godoc.org/github.com/ThalesIgnite/crypto11) for details of various limitations, +especially regarding symmetric crypto. -There are some rudimentary tests. - -There is a demo web server in the `demo` directory, which publishes -the contents of `/usr/share/doc`. Installation ============ -(If you don't have one already) create [a standard Go workspace](https://golang.org/doc/code.html#Workspaces) and set the `GOPATH` environment variable to point to the workspace root. - -crypto11 manages it's dependencies via `dep`. To Install `dep` run: +Since v1.0.0, crypto11 requires Go v1.11+. Install the library by running: - go get -u github.com/golang/dep/cmd/dep +```bash +go get github.com/ThalesIgnite/crypto11 +``` -Clone, ensure deps, and build: +The crypto11 library needs to be configured with information about your PKCS#11 installation. This is either done programmatically +(see the `Config` struct in [the documentation](https://godoc.org/github.com/ThalesIgnite/crypto11)) or via a configuration +file. The configuration file is a JSON representation of the `Config` struct. - go get github.com/ThalesIgnite/crypto11 - cd $GOPATH/src/github.com/ThalesIgnite/crypto11 - dep ensure - go build +A minimal configuration file looks like this: -Edit `config` to taste, and then run the test program: +```json +{ + "Path" : "/usr/lib/softhsm/libsofthsm2.so", + "TokenLabel": "token1", + "Pin" : "password" +} +``` - go test -count=1 +- `Path` points to the library from your PKCS#11 vendor. +- `TokenLabel` is the `CKA_LABEL` of the token you wish to use. +- `Pin` is the password for the `CKU_USER` user. Testing Guidance ================ -Testing with nShield +Testing with SoftHSM2 +--------------------- + +To set up a slot: + + $ cat softhsm2.conf + directories.tokendir = /home/rjk/go/src/github.com/ThalesIgnite/crypto11/tokens + objectstore.backend = file + log.level = INFO + $ mkdir tokens + $ export SOFTHSM2_CONF=`pwd`/softhsm2.conf + $ softhsm2-util --init-token --slot 0 --label test + === SO PIN (4-255 characters) === + Please enter SO PIN: ******** + Please reenter SO PIN: ******** + === User PIN (4-255 characters) === + Please enter user PIN: ******** + Please reenter user PIN: ******** + The token has been initialized. + +The configuration looks like this: + + $ cat config + { + "Path" : "/usr/lib/softhsm/libsofthsm2.so", + "TokenLabel": "test", + "Pin" : "password" + } + +(At time of writing) OAEP is only partial and HMAC is unsupported, so expect test skips. + +Testing with nCipher nShield -------------------- In all cases, it's worth enabling nShield PKCS#11 log output: @@ -94,66 +129,6 @@ To protect keys with the module only, use the 'accelerator' token: (At time of writing) GCM is not implemented, so expect test skips. -Testing with SoftHSM --------------------- - -While the aim of the exercise is to use an HSM, it can be convenient -to test with a software-only provider. - -To set up a slot: - - $ cat softhsm.conf - 0:softhsm0.db - $ export SOFTHSM_CONF=`pwd`/softhsm.conf - $ softhsm --init-token --slot 0 --label test - The SO PIN must have a length between 4 and 255 characters. - Enter SO PIN: - The user PIN must have a length between 4 and 255 characters. - Enter user PIN: - The token has been initialized. - -Configure as follows: - - $ cat config - { - "Path" : "/usr/lib/softhsm/libsofthsm.so", - "TokenLabel": "test", - "Pin" : "password" - } - -DSA, ECDSA, PSS and OAEP aren't supported, so expect test failures. - -Testing with SoftHSM2 ---------------------- - -To set up a slot: - - $ cat softhsm2.conf - directories.tokendir = /home/rjk/go/src/github.com/ThalesIgnite/crypto11/tokens - objectstore.backend = file - log.level = INFO - $ mkdir tokens - $ export SOFTHSM2_CONF=`pwd`/softhsm2.conf - $ softhsm2-util --init-token --slot 0 --label test - === SO PIN (4-255 characters) === - Please enter SO PIN: ******** - Please reenter SO PIN: ******** - === User PIN (4-255 characters) === - Please enter user PIN: ******** - Please reenter user PIN: ******** - The token has been initialized. - -The configuration looks like this: - - $ cat config - { - "Path" : "/usr/lib/softhsm/libsofthsm2.so", - "TokenLabel": "test", - "Pin" : "password" - } - -(At time of writing) OAEP is only partial and HMAC is unsupported, so expect test skips. - Limitations =========== @@ -168,34 +143,13 @@ but you must call the Close() interface (not found in [cipher.BlockMode](https://golang.org/pkg/crypto/cipher/#BlockMode)). See [issue #6](https://github.com/ThalesIgnite/crypto11/issues/6) for further discussion. -Wishlist +Contributions ======== -* Full test instructions for additional PKCS#11 implementations. -* A pony. - -Copyright -========= - -MIT License. - -Copyright 2016-2018 Thales e-Security, Inc +Contributions are gratefully received. Before beginning work on sizeable changes, please open an issue first to +discuss. -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: +Here are some topics we'd like to cover: -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +* Full test instructions for additional PKCS#11 implementations. +* Move to another resource pool implementation (`github.com/vitessio/vitess` is a big dependency) diff --git a/aead.go b/aead.go index 490fb07..c0c720c 100644 --- a/aead.go +++ b/aead.go @@ -25,21 +25,25 @@ import ( "crypto/cipher" "errors" "fmt" + "github.com/miekg/pkcs11" ) // cipher.AEAD ---------------------------------------------------------- +// A PaddingMode is used by a block cipher (see NewCBC). +type PaddingMode int + const ( - // PaddingNone represents a block cipher with no padding. (See NewCBC.) - PaddingNone = iota + // PaddingNone represents a block cipher with no padding. + PaddingNone PaddingMode = iota - // PaddingPKCS represents a block cipher used with PKCS#7 padding. (See NewCBC.) + // PaddingPKCS represents a block cipher used with PKCS#7 padding. PaddingPKCS ) type genericAead struct { - key *PKCS11SecretKey + key *SecretKey overhead int @@ -53,22 +57,21 @@ type genericAead struct { // // This depends on the HSM supporting the CKM_*_GCM mechanism. If it is not supported // then you must use cipher.NewGCM; it will be slow. -func (key *PKCS11SecretKey) NewGCM() (g cipher.AEAD, err error) { +func (key *SecretKey) NewGCM() (cipher.AEAD, error) { if key.Cipher.GCMMech == 0 { - err = fmt.Errorf("GCM not implemented for key type %#x", key.Cipher.GenParams[0].KeyType) - return + return nil, fmt.Errorf("GCM not implemented for key type %#x", key.Cipher.GenParams[0].KeyType) } - g = genericAead{ + + g := genericAead{ key: key, overhead: 16, nonceSize: 12, - makeMech: func(nonce []byte, additionalData []byte) (mech []*pkcs11.Mechanism, error error) { + makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, error) { params := pkcs11.NewGCMParams(nonce, additionalData, 16*8 /*bits*/) - mech = []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.GCMMech, params)} - return + return []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.GCMMech, params)}, nil }, } - return + return g, nil } // NewCBC returns a given cipher wrapped in CBC mode. @@ -76,34 +79,33 @@ func (key *PKCS11SecretKey) NewGCM() (g cipher.AEAD, err error) { // Despite the cipher.AEAD return type, there is no support for additional data and no authentication. // This method exists to provide a convenient way to do bulk (possibly padded) CBC encryption. // Think carefully before passing the cipher.AEAD to any consumer that expects authentication. -func (key *PKCS11SecretKey) NewCBC(paddingMode int) (g cipher.AEAD, err error) { - g = genericAead{ +func (key *SecretKey) NewCBC(paddingMode PaddingMode) (cipher.AEAD, error) { + + var pkcsMech uint + + switch paddingMode { + case PaddingNone: + pkcsMech = key.Cipher.CBCMech + case PaddingPKCS: + pkcsMech = key.Cipher.CBCPKCSMech + default: + return nil, errors.New("unrecognized padding mode") + } + + g := genericAead{ key: key, overhead: 0, nonceSize: key.BlockSize(), - makeMech: func(nonce []byte, additionalData []byte) (mech []*pkcs11.Mechanism, error error) { + makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, error) { if len(additionalData) > 0 { - err = errors.New("additional data not supported for CBC mode") - } - var pkcsMech uint - switch paddingMode { - case PaddingNone: - pkcsMech = key.Cipher.CBCMech - case PaddingPKCS: - pkcsMech = key.Cipher.CBCPKCSMech - default: - err = errors.New("unrecognized padding mode") - return - } - if pkcsMech == 0 { - err = errors.New("unsupported padding mode") - return + return nil, errors.New("additional data not supported for CBC mode") } - mech = []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcsMech, nonce)} - return + + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcsMech, nonce)}, nil }, } - return + + return g, nil } func (g genericAead) NonceSize() int { @@ -116,16 +118,16 @@ func (g genericAead) Overhead() int { func (g genericAead) Seal(dst, nonce, plaintext, additionalData []byte) []byte { var result []byte - if err := withSession(g.key.Slot, func(session *PKCS11Session) (err error) { + if err := g.key.context.withSession(func(session *pkcs11Session) (err error) { var mech []*pkcs11.Mechanism if mech, err = g.makeMech(nonce, additionalData); err != nil { return } - if err = session.Ctx.EncryptInit(session.Handle, mech, g.key.Handle); err != nil { + if err = session.ctx.EncryptInit(session.handle, mech, g.key.handle); err != nil { err = fmt.Errorf("C_EncryptInit: %v", err) return } - if result, err = session.Ctx.Encrypt(session.Handle, plaintext); err != nil { + if result, err = session.ctx.Encrypt(session.handle, plaintext); err != nil { err = fmt.Errorf("C_Encrypt: %v", err) return } @@ -140,16 +142,16 @@ func (g genericAead) Seal(dst, nonce, plaintext, additionalData []byte) []byte { func (g genericAead) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { var result []byte - if err := withSession(g.key.Slot, func(session *PKCS11Session) (err error) { + if err := g.key.context.withSession(func(session *pkcs11Session) (err error) { var mech []*pkcs11.Mechanism if mech, err = g.makeMech(nonce, additionalData); err != nil { return } - if err = session.Ctx.DecryptInit(session.Handle, mech, g.key.Handle); err != nil { + if err = session.ctx.DecryptInit(session.handle, mech, g.key.handle); err != nil { err = fmt.Errorf("C_DecryptInit: %v", err) return } - if result, err = session.Ctx.Decrypt(session.Handle, ciphertext); err != nil { + if result, err = session.ctx.Decrypt(session.handle, ciphertext); err != nil { err = fmt.Errorf("C_Decrypt: %v", err) return } diff --git a/block.go b/block.go index d3af6fd..6c612c9 100644 --- a/block.go +++ b/block.go @@ -23,13 +23,14 @@ package crypto11 import ( "fmt" + "github.com/miekg/pkcs11" ) // cipher.Block --------------------------------------------------------- // BlockSize returns the cipher's block size in bytes. -func (key *PKCS11SecretKey) BlockSize() int { +func (key *SecretKey) BlockSize() int { return key.Cipher.BlockSize } @@ -39,14 +40,14 @@ func (key *PKCS11SecretKey) BlockSize() int { // Using this method for bulk operation is very inefficient, as it makes a round trip to the HSM // (which may be network-connected) for each block. // For more efficient operation, see NewCBCDecrypterCloser, NewCBCDecrypter or NewCBC. -func (key *PKCS11SecretKey) Decrypt(dst, src []byte) { +func (key *SecretKey) Decrypt(dst, src []byte) { var result []byte - if err := withSession(key.Slot, func(session *PKCS11Session) (err error) { + if err := key.context.withSession(func(session *pkcs11Session) (err error) { mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.ECBMech, nil)} - if err = session.Ctx.DecryptInit(session.Handle, mech, key.Handle); err != nil { + if err = session.ctx.DecryptInit(session.handle, mech, key.handle); err != nil { return } - if result, err = session.Ctx.Decrypt(session.Handle, src[:key.Cipher.BlockSize]); err != nil { + if result, err = session.ctx.Decrypt(session.handle, src[:key.Cipher.BlockSize]); err != nil { return } if len(result) != key.Cipher.BlockSize { @@ -67,14 +68,14 @@ func (key *PKCS11SecretKey) Decrypt(dst, src []byte) { // Using this method for bulk operation is very inefficient, as it makes a round trip to the HSM // (which may be network-connected) for each block. // For more efficient operation, see NewCBCEncrypterCloser, NewCBCEncrypter or NewCBC. -func (key *PKCS11SecretKey) Encrypt(dst, src []byte) { +func (key *SecretKey) Encrypt(dst, src []byte) { var result []byte - if err := withSession(key.Slot, func(session *PKCS11Session) (err error) { + if err := key.context.withSession(func(session *pkcs11Session) (err error) { mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.ECBMech, nil)} - if err = session.Ctx.EncryptInit(session.Handle, mech, key.Handle); err != nil { + if err = session.ctx.EncryptInit(session.handle, mech, key.handle); err != nil { return } - if result, err = session.Ctx.Encrypt(session.Handle, src[:key.Cipher.BlockSize]); err != nil { + if result, err = session.ctx.Encrypt(session.handle, src[:key.Cipher.BlockSize]); err != nil { return } if len(result) != key.Cipher.BlockSize { diff --git a/blockmode.go b/blockmode.go index 182c2d6..f9cea6a 100644 --- a/blockmode.go +++ b/blockmode.go @@ -22,12 +22,10 @@ package crypto11 import ( - "context" "crypto/cipher" - "fmt" - "github.com/miekg/pkcs11" - "github.com/youtube/vitess/go/pools" "runtime" + + "github.com/miekg/pkcs11" ) // cipher.BlockMode ----------------------------------------------------- @@ -57,7 +55,7 @@ const ( // If this is a problem for your application then use NewCBCEncrypterCloser instead. // // If that is not possible then adding calls to runtime.GC() may help. -func (key *PKCS11SecretKey) NewCBCEncrypter(iv []byte) (bm cipher.BlockMode, err error) { +func (key *SecretKey) NewCBCEncrypter(iv []byte) (cipher.BlockMode, error) { return key.newBlockModeCloser(key.Cipher.CBCMech, modeEncrypt, iv, true) } @@ -68,7 +66,7 @@ func (key *PKCS11SecretKey) NewCBCEncrypter(iv []byte) (bm cipher.BlockMode, err // If this is a problem for your application then use NewCBCDecrypterCloser instead. // // If that is not possible then adding calls to runtime.GC() may help. -func (key *PKCS11SecretKey) NewCBCDecrypter(iv []byte) (bm cipher.BlockMode, err error) { +func (key *SecretKey) NewCBCDecrypter(iv []byte) (cipher.BlockMode, error) { return key.newBlockModeCloser(key.Cipher.CBCMech, modeDecrypt, iv, true) } @@ -77,7 +75,7 @@ func (key *PKCS11SecretKey) NewCBCDecrypter(iv []byte) (bm cipher.BlockMode, err // // Use of NewCBCEncrypterCloser rather than NewCBCEncrypter represents a commitment to call the Close() method // of the returned BlockModeCloser. -func (key *PKCS11SecretKey) NewCBCEncrypterCloser(iv []byte) (bmc BlockModeCloser, err error) { +func (key *SecretKey) NewCBCEncrypterCloser(iv []byte) (BlockModeCloser, error) { return key.newBlockModeCloser(key.Cipher.CBCMech, modeEncrypt, iv, false) } @@ -86,14 +84,14 @@ func (key *PKCS11SecretKey) NewCBCEncrypterCloser(iv []byte) (bmc BlockModeClose // // Use of NewCBCDecrypterCloser rather than NewCBCEncrypter represents a commitment to call the Close() method // of the returned BlockModeCloser. -func (key *PKCS11SecretKey) NewCBCDecrypterCloser(iv []byte) (bmc BlockModeCloser, err error) { +func (key *SecretKey) NewCBCDecrypterCloser(iv []byte) (BlockModeCloser, error) { return key.newBlockModeCloser(key.Cipher.CBCMech, modeDecrypt, iv, false) } // blockModeCloser is a concrete implementation of BlockModeCloser supporting CBC. type blockModeCloser struct { // PKCS#11 session to use - session *PKCS11Session + session *pkcs11Session // Cipher block size blockSize int @@ -106,48 +104,40 @@ type blockModeCloser struct { } // newBlockModeCloser creates a new blockModeCloser for the chosen mechanism and mode. -func (key *PKCS11SecretKey) newBlockModeCloser(mech uint, mode int, iv []byte, setFinalizer bool) (bmc *blockModeCloser, err error) { - // TODO maybe refactor with withSession() - sessionPool := pool.Get(key.Slot) - if sessionPool == nil { - err = fmt.Errorf("crypto11: no session for slot %d", key.Slot) - return - } - ctx := context.Background() - if instance.cfg.PoolWaitTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), instance.cfg.PoolWaitTimeout) - defer cancel() - } - var session pools.Resource - if session, err = sessionPool.Get(ctx); err != nil { - return +func (key *SecretKey) newBlockModeCloser(mech uint, mode int, iv []byte, setFinalizer bool) (*blockModeCloser, error) { + + session, err := key.context.getSession() + if err != nil { + return nil, err } - bmc = &blockModeCloser{ - session: session.(*PKCS11Session), + + bmc := &blockModeCloser{ + session: session, blockSize: key.Cipher.BlockSize, mode: mode, cleanup: func() { - sessionPool.Put(session) + key.context.pool.Put(session) }, } mechDescription := []*pkcs11.Mechanism{pkcs11.NewMechanism(mech, iv)} + switch mode { case modeDecrypt: - err = bmc.session.Ctx.DecryptInit(bmc.session.Handle, mechDescription, key.Handle) + err = session.ctx.DecryptInit(session.handle, mechDescription, key.handle) case modeEncrypt: - err = bmc.session.Ctx.EncryptInit(bmc.session.Handle, mechDescription, key.Handle) + err = session.ctx.EncryptInit(bmc.session.handle, mechDescription, key.handle) default: panic("unexpected mode") } if err != nil { bmc.cleanup() - return + return nil, err } if setFinalizer { runtime.SetFinalizer(bmc, finalizeBlockModeCloser) } - return + + return bmc, nil } func finalizeBlockModeCloser(obj interface{}) { @@ -169,9 +159,9 @@ func (bmc *blockModeCloser) CryptBlocks(dst, src []byte) { var err error switch bmc.mode { case modeDecrypt: - result, err = bmc.session.Ctx.DecryptUpdate(bmc.session.Handle, src) + result, err = bmc.session.ctx.DecryptUpdate(bmc.session.handle, src) case modeEncrypt: - result, err = bmc.session.Ctx.EncryptUpdate(bmc.session.Handle, src) + result, err = bmc.session.ctx.EncryptUpdate(bmc.session.handle, src) } if err != nil { panic(err) @@ -194,9 +184,9 @@ func (bmc *blockModeCloser) Close() { var err error switch bmc.mode { case modeDecrypt: - result, err = bmc.session.Ctx.DecryptFinal(bmc.session.Handle) + result, err = bmc.session.ctx.DecryptFinal(bmc.session.handle) case modeEncrypt: - result, err = bmc.session.Ctx.EncryptFinal(bmc.session.Handle) + result, err = bmc.session.ctx.EncryptFinal(bmc.session.handle) } bmc.session = nil bmc.cleanup() diff --git a/close_test.go b/close_test.go index 5320599..e5d88c5 100644 --- a/close_test.go +++ b/close_test.go @@ -22,48 +22,92 @@ package crypto11 import ( - "crypto" "crypto/dsa" + "crypto/elliptic" "fmt" + "math/rand" "testing" + "time" + + "github.com/stretchr/testify/require" ) func TestClose(t *testing.T) { // Verify that close and re-open works. - var err error - var key *PKCS11PrivateKeyDSA - if _, err := ConfigureFromFile("config"); err != nil { - t.Fatal(err) - } - psize := dsa.L1024N160 - if key, err = GenerateDSAKeyPair(dsaSizes[psize]); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } - var id []byte - if id, _, err = key.Identify(); err != nil { - t.Errorf("crypto11.dsa.PKCS11PrivateKeyDSA.Identify: %v", err) - return - } - if err = Close(); err != nil { - t.Fatal(err) - } + + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + const pSize = dsa.L1024N160 + id := randomBytes() + key, err := ctx.GenerateDSAKeyPair(id, dsaSizes[pSize]) + require.NoError(t, err) + require.NotNil(t, key) + + require.NoError(t, ctx.Close()) + for i := 0; i < 5; i++ { - if _, err := ConfigureFromFile("config"); err != nil { - t.Fatal(err) - } - var key2 crypto.PrivateKey - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.dsa.FindDSAKeyPair by id: %v", err) - return - } - testDsaSigning(t, key2.(*PKCS11PrivateKeyDSA), psize, fmt.Sprintf("close%d", i)) - if err = Close(); err != nil { - t.Fatal(err) - } + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + key2, err := ctx.FindKeyPair(id, nil) + require.NoError(t, err) + + testDsaSigning(t, key2.(*pkcs11PrivateKeyDSA), pSize, fmt.Sprintf("close%d", i)) + require.NoError(t, ctx.Close()) } } + +// randomBytes returns 32 random bytes. +func randomBytes() []byte { + result := make([]byte, 32) + rand.Read(result) + return result +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func TestErrorAfterClosed(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + err = ctx.Close() + require.NoError(t, err) + + bytes := randomBytes() + + _, err = ctx.FindKey(bytes, nil) + require.Equal(t, errClosed, err) + + _, err = ctx.FindKeyPair(bytes, nil) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateSecretKey(bytes, 256, CipherAES) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateSecretKeyWithLabel(bytes, bytes, 256, CipherAES) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateRSAKeyPair(bytes, 2048) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateRSAKeyPairWithLabel(bytes, bytes, 2048) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateDSAKeyPair(bytes, dsaSizes[dsa.L1024N160]) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateDSAKeyPairWithLabel(bytes, bytes, dsaSizes[dsa.L1024N160]) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateECDSAKeyPair(bytes, elliptic.P224()) + require.Equal(t, errClosed, err) + + _, err = ctx.GenerateECDSAKeyPairWithLabel(bytes, bytes, elliptic.P224()) + require.Equal(t, errClosed, err) + + _, err = ctx.NewRandomReader() + require.Equal(t, errClosed, err) +} diff --git a/common.go b/common.go index 68c955c..2aaac8f 100644 --- a/common.go +++ b/common.go @@ -1,26 +1,36 @@ +// Copyright 2017 Thales e-Security, Inc +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + package crypto11 import ( "C" "encoding/asn1" - "encoding/hex" - "errors" "math/big" "unsafe" "github.com/miekg/pkcs11" + "github.com/pkg/errors" ) -// ErrMalformedDER represents a failure to decode an ASN.1-encoded message -var ErrMalformedDER = errors.New("crypto11: malformed DER message") - -// ErrMalformedSignature represents a failure to decode a signature. This -// means the PKCS#11 library has returned an empty or odd-length byte -// string. -var ErrMalformedSignature = errors.New("crypto11xo: malformed signature") - -const labelLength = 64 - func ulongToBytes(n uint) []byte { return C.GoBytes(unsafe.Pointer(&n), C.sizeof_ulong) // ugh! } @@ -50,7 +60,7 @@ type dsaSignature struct { // Populate a dsaSignature from a raw byte sequence func (sig *dsaSignature) unmarshalBytes(sigBytes []byte) error { if len(sigBytes) == 0 || len(sigBytes)%2 != 0 { - return ErrMalformedSignature + return errors.New("DSA signature length is invalid from token") } n := len(sigBytes) / 2 sig.R, sig.S = new(big.Int), new(big.Int) @@ -62,9 +72,9 @@ func (sig *dsaSignature) unmarshalBytes(sigBytes []byte) error { // Populate a dsaSignature from DER encoding func (sig *dsaSignature) unmarshalDER(sigDER []byte) error { if rest, err := asn1.Unmarshal(sigDER, sig); err != nil { - return err + return errors.WithMessage(err, "DSA signature contains invalid ASN.1 data") } else if len(rest) > 0 { - return ErrMalformedDER + return errors.New("unexpected data found after DSA signature") } return nil } @@ -75,16 +85,16 @@ func (sig *dsaSignature) marshalDER() ([]byte, error) { } // Compute *DSA signature and marshal the result in DER form -func dsaGeneric(slot uint, key pkcs11.ObjectHandle, mechanism uint, digest []byte) ([]byte, error) { +func (c *Context) dsaGeneric(key pkcs11.ObjectHandle, mechanism uint, digest []byte) ([]byte, error) { var err error var sigBytes []byte var sig dsaSignature mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(mechanism, nil)} - err = withSession(slot, func(session *PKCS11Session) error { - if err = instance.ctx.SignInit(session.Handle, mech, key); err != nil { + err = c.withSession(func(session *pkcs11Session) error { + if err = c.ctx.SignInit(session.handle, mech, key); err != nil { return err } - sigBytes, err = instance.ctx.Sign(session.Handle, digest) + sigBytes, err = c.ctx.Sign(session.handle, digest) return err }) if err != nil { @@ -97,19 +107,3 @@ func dsaGeneric(slot uint, key pkcs11.ObjectHandle, mechanism uint, digest []byt return sig.marshalDER() } - -// Pick a random label for a key -func generateKeyLabel() ([]byte, error) { - rawLabel := make([]byte, labelLength / 2) - var rand PKCS11RandReader - sz, err := rand.Read(rawLabel) - if err != nil { - return nil, err - } - if sz < len(rawLabel) { - return nil, ErrCannotGetRandomData - } - label := make([]byte, labelLength) - hex.Encode(label, rawLabel) - return label, nil -} diff --git a/common_test.go b/common_test.go deleted file mode 100644 index 0c56825..0000000 --- a/common_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package crypto11 - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestGenerateKeyLabel(t *testing.T) { - _, err := ConfigureFromFile("config") - require.NoError(t, err) - - for i :=0; i < 100; i++ { - label, err := generateKeyLabel() - require.NoError(t, err) - require.Len(t, label, labelLength) - for _, b := range label { - require.NotEqual(t, byte(0), b) - } - } -} diff --git a/config b/config index 1550b31..c51f9c5 100644 --- a/config +++ b/config @@ -1,5 +1,5 @@ { - "Path" : "/opt/nfast/toolkits/pkcs11/libcknfast.so", - "TokenLabel": "rjk", + "Path" : "/usr/lib/softhsm/libsofthsm2.so", + "TokenLabel": "token1", "Pin" : "password" } diff --git a/configs/config.nshield b/configs/config.nshield deleted file mode 100644 index 1550b31..0000000 --- a/configs/config.nshield +++ /dev/null @@ -1,5 +0,0 @@ -{ - "Path" : "/opt/nfast/toolkits/pkcs11/libcknfast.so", - "TokenLabel": "rjk", - "Pin" : "password" -} diff --git a/configs/config.softhsm b/configs/config.softhsm deleted file mode 100644 index 2d37243..0000000 --- a/configs/config.softhsm +++ /dev/null @@ -1,5 +0,0 @@ -{ - "Path" : "/usr/lib/softhsm/libsofthsm.so", - "TokenLabel": "test", - "Pin" : "password" -} diff --git a/configs/config.softhsm2 b/configs/config.softhsm2 deleted file mode 100644 index 7795394..0000000 --- a/configs/config.softhsm2 +++ /dev/null @@ -1,5 +0,0 @@ -{ - "Path" : "/usr/lib/softhsm/libsofthsm2.so", - "TokenLabel": "test", - "Pin" : "password" -} diff --git a/crypto11.go b/crypto11.go index 9b49827..f881461 100644 --- a/crypto11.go +++ b/crypto11.go @@ -1,4 +1,4 @@ -// Copyright 2016, 2017 Thales e-Security, Inc +// Copyright 2016 Thales e-Security, Inc // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -21,24 +21,26 @@ // Package crypto11 enables access to cryptographic keys from PKCS#11 using Go crypto API. // -// Simple use +// Configuration // -// 1. Either write a configuration file (see ConfigureFromFile) or -// define a configuration in your application (see PKCS11Config and -// Configure). This will identify the PKCS#11 library and token to -// use, and contain the password (or "PIN" in PKCS#11 terminology) to -// use if the token requires login. +// PKCS#11 tokens are accessed via Context objects. Each Context connects to one token. // -// 2. Create keys with GenerateDSAKeyPair, GenerateRSAKeyPair and -// GenerateECDSAKeyPair. The keys you get back implement the standard -// Go crypto.Signer interface (and crypto.Decrypter, for RSA). They -// are automatically persisted under random a randomly generated label -// and ID (use the Identify method to discover them). +// Context objects are created by calling Configure or ConfigureFromFile. +// In the latter case, the file should contain a JSON representation of +// a Config. // -// 3. Retrieve existing keys with FindKeyPair. The return value is a -// Go crypto.PrivateKey; it may be converted either to crypto.Signer -// or to *PKCS11PrivateKeyDSA, *PKCS11PrivateKeyECDSA or -// *PKCS11PrivateKeyRSA. +// Key Generation and Usage +// +// There is support for generating DSA, RSA and ECDSA keys. These keys +// can be found later using FindKeyPair. All three key types implement +// the crypto.Signer interface and the RSA keys also implement crypto.Decrypter. +// +// RSA keys obtained through FindKeyPair will need a type assertion to be +// used for decryption. Assert either crypto.Decrypter or SignerDecrypter, as you +// prefer. +// +// Symmetric keys can also be generated. These are found later using FindKey. +// See the documentation for SecretKey for further information. // // Sessions and concurrency // @@ -47,22 +49,30 @@ // nothing of this and expect to be able to sign from multiple threads // without constraint. We address this as follows. // -// 1. PKCS11Object captures both the object handle and the slot ID -// for an object. +// 1. When a Context is created, a session is created and the user is +// logged in. This session remains open until the Context is closed, +// to ensure all object handles remain valid and to avoid repeatedly +// calling C_Login. // -// 2. For each slot we maintain a pool of read-write sessions. The -// pool expands dynamically up to an (undocumented) limit. +// 2. The Context also maintains a pool of read-write sessions. The pool expands +// dynamically as needed, but never beyond the maximum number of r/w sessions +// supported by the token (as reported by C_GetInfo). If other applications +// are using the token, a lower limit should be set in the Config. // // 3. Each operation transiently takes a session from the pool. They // have exclusive use of the session, meeting PKCS#11's concurrency -// requirements. +// requirements. Sessions are returned to the pool afterwards and may +// be re-used. // -// The details are, partially, exposed in the API; since the target -// use case is PKCS#11-unaware operation it may be that the API as it -// stands isn't good enough for PKCS#11-aware applications. Feedback -// welcome. +// Behaviour of the pool can be tweaked via Config fields: // -// See also https://golang.org/pkg/crypto/ +// - PoolWaitTimeout controls how long an operation can block waiting on a +// session from the pool. A zero value means there is no limit. Timeouts +// occur if the pool is fully used and additional operations are requested. +// +// - MaxSessions sets an upper bound on the number of sessions. If this value is zero, +// a default maximum is used (see DefaultMaxSessions). In every case the maximum +// supported sessions as reported by the token is obeyed. // // Limitations // @@ -80,217 +90,273 @@ package crypto11 import ( "crypto" "encoding/json" - "errors" - "fmt" - "log" + "io" "os" "time" + "github.com/vitessio/vitess/go/sync2" + "github.com/miekg/pkcs11" + "github.com/pkg/errors" + "github.com/vitessio/vitess/go/pools" ) const ( // DefaultMaxSessions controls the maximum number of concurrent sessions to - // open, unless otherwise specified in the PKCS11Config object. + // open, unless otherwise specified in the Config object. DefaultMaxSessions = 1024 ) -// ErrTokenNotFound represents the failure to find the requested PKCS#11 token -var ErrTokenNotFound = errors.New("crypto11: could not find PKCS#11 token") +// errTokenNotFound represents the failure to find the requested PKCS#11 token +var errTokenNotFound = errors.New("could not find PKCS#11 token") -// ErrKeyNotFound represents the failure to find the requested PKCS#11 key -var ErrKeyNotFound = errors.New("crypto11: could not find PKCS#11 key") +// errClosed is returned if a Context is used after a call to Close. +var errClosed = errors.New("cannot used closed Context") -// ErrNotConfigured is returned when the PKCS#11 library is not configured -var ErrNotConfigured = errors.New("crypto11: PKCS#11 not yet configured") +// errAmbiguousToken is returned if the supplied Config specifies more than one way to select the token. +var errAmbiguousToken = errors.New("config must only specify one way to select a token") -// ErrCannotOpenPKCS11 is returned when the PKCS#11 library cannot be opened -var ErrCannotOpenPKCS11 = errors.New("crypto11: could not open PKCS#11") - -// ErrCannotGetRandomData is returned when the PKCS#11 library fails to return enough random data -var ErrCannotGetRandomData = errors.New("crypto11: cannot get random data from PKCS#11") - -// ErrUnsupportedKeyType is returned when the PKCS#11 library returns a key type that isn't supported -var ErrUnsupportedKeyType = errors.New("crypto11: unrecognized key type") - -// PKCS11Object contains a reference to a loaded PKCS#11 object. -type PKCS11Object struct { +// pkcs11Object contains a reference to a loaded PKCS#11 object. +type pkcs11Object struct { // The PKCS#11 object handle. - Handle pkcs11.ObjectHandle + handle pkcs11.ObjectHandle - // The PKCS#11 slot number. - // - // This is used internally to find a session handle that can + // The PKCS#11 context. This is used to find a session handle that can // access this object. - Slot uint + context *Context +} + +func (o *pkcs11Object) Delete() error { + return o.context.withSession(func(session *pkcs11Session) error { + err := session.ctx.DestroyObject(session.handle, o.handle) + return errors.WithMessage(err, "failed to destroy key") + }) } -// PKCS11PrivateKey contains a reference to a loaded PKCS#11 private key object. -type PKCS11PrivateKey struct { - PKCS11Object +// pkcs11PrivateKey contains a reference to a loaded PKCS#11 private key object. +type pkcs11PrivateKey struct { + pkcs11Object - // The corresponding public key - PubKey crypto.PublicKey + // pubKeyHandle is a handle to the public key. + pubKeyHandle pkcs11.ObjectHandle + + // pubKey is an exported copy of the public key. We pre-export the key material because crypto.Signer.Public + // doesn't allow us to return errors. + pubKey crypto.PublicKey } -// In a former design we carried around the object handle for the -// public key and retrieved it on demand. The downside of that is -// that the Public() method on Signer &c has no way to communicate -// errors. - -/* Nasty globals */ -var instance = &libCtx{ - cfg: &PKCS11Config{ - MaxSessions: DefaultMaxSessions, - IdleTimeout: 0, - PoolWaitTimeout: 0, - }, +// Delete implements Signer.Delete. +func (k *pkcs11PrivateKey) Delete() error { + err := k.pkcs11Object.Delete() + if err != nil { + return err + } + + return k.context.withSession(func(session *pkcs11Session) error { + err := session.ctx.DestroyObject(session.handle, k.pubKeyHandle) + return errors.WithMessage(err, "failed to destroy public key") + }) } -// Represent library pkcs11 context and token configuration -type libCtx struct { +// A Context stores the connection state to a PKCS#11 token. Use Configure or ConfigureFromFile to create a new +// Context. Call Close when finished with the token, to free up resources. +// +// All functions, except Close, are safe to call from multiple goroutines. +type Context struct { + // Atomic fields must be at top (according to the package owners) + closed sync2.AtomicBool + ctx *pkcs11.Ctx - cfg *PKCS11Config + cfg *Config token *pkcs11.TokenInfo slot uint + pool *pools.ResourcePool + + // persistentSession is a session held open so we can be confident handles and login status + // persist for the duration of this context + persistentSession pkcs11.SessionHandle +} + +// Signer is a PKCS#11 key that implements crypto.Signer. +type Signer interface { + crypto.Signer + + // Delete deletes the key pair from the token. + Delete() error } -// Find a token given its serial number -func findToken(slots []uint, serial string, label string) (uint, *pkcs11.TokenInfo, error) { +// SignerDecrypter is a PKCS#11 key implements crypto.Signer and crypto.Decrypter. +type SignerDecrypter interface { + Signer + + // Decrypt implements crypto.Decrypter. + Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) +} + +// findToken finds a token given exactly one of serial, label or slotNumber +func (c *Context) findToken(slots []uint, serial, label string, slotNumber *int) (uint, *pkcs11.TokenInfo, error) { for _, slot := range slots { - tokenInfo, err := instance.ctx.GetTokenInfo(slot) + + tokenInfo, err := c.ctx.GetTokenInfo(slot) if err != nil { return 0, nil, err } - if tokenInfo.SerialNumber == serial { - return slot, &tokenInfo, nil - } - if tokenInfo.Label == label { + + if (slotNumber != nil && uint(*slotNumber) == slot) || + (tokenInfo.SerialNumber != "" && tokenInfo.SerialNumber == serial) || + (tokenInfo.Label != "" && tokenInfo.Label == label) { + return slot, &tokenInfo, nil } + } - return 0, nil, ErrTokenNotFound + return 0, nil, errTokenNotFound } -// PKCS11Config holds PKCS#11 configuration information. +// Config holds PKCS#11 configuration information. // -// A token may be identified either by serial number or label. If -// both are specified then the first match wins. +// A token may be selected by label, serial number or slot number. It is an error to specify +// more than one way to select the token. // // Supply this to Configure(), or alternatively use ConfigureFromFile(). -type PKCS11Config struct { - // Full path to PKCS#11 library +type Config struct { + // Full path to PKCS#11 library. Path string - // Token serial number + // Token serial number. TokenSerial string - // Token label + // Token label. TokenLabel string - // User PIN (password) + // SlotNumber identifies a token to use by the slot containing it. + SlotNumber *int + + // User PIN (password). Pin string - // Maximum number of concurrent sessions to open + // Maximum number of concurrent sessions to open. If zero, DefaultMaxSessions is used. MaxSessions int - // Session idle timeout to be evicted from the pool - IdleTimeout time.Duration - - // Maximum time allowed to wait a sessions pool for a session + // Maximum time to wait for a session from the sessions pool. Zero means wait indefinitely. PoolWaitTimeout time.Duration } -// Configure configures PKCS#11 from a PKCS11Config. -// -// The PKCS#11 library context is returned, -// allowing a PKCS#11-aware application to make use of it. Non-aware -// appliations may ignore it. -// -// Unsually, these values may be present even if the error is -// non-nil. This corresponds to the case that the library has already -// been configured. Note that it is NOT reconfigured so if you supply -// a different configuration the second time, it will be ignored in -// favor of the first configuration. -// -// If config is nil, and the library has already been configured, the -// context from the first configuration is returned (and -// the error will be nil in this case). -func Configure(config *PKCS11Config) (*pkcs11.Ctx, error) { - var err error - var slots []uint - - if config == nil { - if instance.ctx != nil { - return instance.ctx, nil - } - return nil, ErrNotConfigured +// Configure creates a new Context based on the supplied PKCS#11 configuration. +func Configure(config *Config) (*Context, error) { + // Have we been given exactly one way to select a token? + count := 0 + if config.SlotNumber != nil { + count++ } - if instance.ctx != nil { - log.Printf("PKCS#11 library already configured") - return instance.ctx, nil + if config.TokenLabel != "" { + count++ + } + if config.TokenSerial != "" { + count++ + } + if count != 1 { + return nil, errAmbiguousToken } if config.MaxSessions == 0 { config.MaxSessions = DefaultMaxSessions } - instance.cfg = config - instance.ctx = pkcs11.New(config.Path) + + instance := &Context{ + cfg: config, + ctx: pkcs11.New(config.Path), + } + if instance.ctx == nil { - log.Printf("Could not open PKCS#11 library: %s", config.Path) - return nil, ErrCannotOpenPKCS11 + return nil, errors.New("could not open PKCS#11") } - if err = instance.ctx.Initialize(); err != nil { - log.Printf("Failed to initialize PKCS#11 library: %s", err.Error()) - return nil, err + if err := instance.ctx.Initialize(); err != nil { + instance.ctx.Destroy() + return nil, errors.WithMessage(err, "failed to initialize PKCS#11 library") } - if slots, err = instance.ctx.GetSlotList(true); err != nil { - log.Printf("Failed to list PKCS#11 Slots: %s", err.Error()) - return nil, err + slots, err := instance.ctx.GetSlotList(true) + if err != nil { + _ = instance.ctx.Finalize() + instance.ctx.Destroy() + return nil, errors.WithMessage(err, "failed to list PKCS#11 slots") } - instance.slot, instance.token, err = findToken(slots, config.TokenSerial, config.TokenLabel) + instance.slot, instance.token, err = instance.findToken(slots, config.TokenSerial, config.TokenLabel, config.SlotNumber) if err != nil { - log.Printf("Failed to find Token in any Slot: %s", err.Error()) + _ = instance.ctx.Finalize() + instance.ctx.Destroy() return nil, err } - if instance.token.MaxRwSessionCount > 0 && uint(instance.cfg.MaxSessions) > instance.token.MaxRwSessionCount { - return nil, fmt.Errorf("crypto11: provided max sessions value (%d) exceeds max value the token supports (%d)", instance.cfg.MaxSessions, instance.token.MaxRwSessionCount) + // Create the session pool. + maxSessions := instance.cfg.MaxSessions + tokenMaxSessions := instance.token.MaxRwSessionCount + if tokenMaxSessions != pkcs11.CK_EFFECTIVELY_INFINITE && tokenMaxSessions != pkcs11.CK_UNAVAILABLE_INFORMATION { + maxSessions = min(maxSessions, castDown(tokenMaxSessions)) } - if err := setupSessions(instance, instance.slot); err != nil { - return nil, err + // We will use one session to keep state alive, so the pool gets maxSessions - 1 + instance.pool = pools.NewResourcePool(instance.resourcePoolFactoryFunc, maxSessions-1, maxSessions-1, 0) + + // Create a long-term session and log it in. This session won't be used by callers, instead it is used to keep + // a connection alive to the token to ensure object handles and the log in status remain accessible. + instance.persistentSession, err = instance.ctx.OpenSession(instance.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + if err != nil { + _ = instance.ctx.Finalize() + instance.ctx.Destroy() + return nil, errors.WithMessagef(err, "failed to create long term session") + } + err = instance.ctx.Login(instance.persistentSession, pkcs11.CKU_USER, instance.cfg.Pin) + if err != nil { + _ = instance.ctx.Finalize() + instance.ctx.Destroy() + return nil, errors.WithMessagef(err, "failed to log into long term session") } - // login required if a pool evict idle sessions (handled by the pool) or - // for the first connection in the pool (handled here) - if instance.cfg.IdleTimeout == 0 { - if instance.token.Flags&pkcs11.CKF_LOGIN_REQUIRED != 0 && instance.cfg.Pin != "" { - if err = withSession(instance.slot, loginToken); err != nil { - return nil, err - } - } + return instance, nil +} + +func min(a, b int) int { + if b < a { + return b + } + return a +} + +// castDown returns orig as a signed integer. If an overflow would have occurred, +// the maximum possible value is returned. +func castDown(orig uint) int { + // From https://stackoverflow.com/a/6878625/474189 + const maxUint = ^uint(0) + const maxInt = int(maxUint >> 1) + + if orig > uint(maxInt) { + return maxInt } - return instance.ctx, nil + return int(orig) } -// ConfigureFromFile configures PKCS#11 from a name configuration file. -// -// Configuration files are a JSON representation of the PKCSConfig object. -// The return value is as for Configure(). -// -// Note that if CRYPTO11_CONFIG_PATH is set in the environment, -// configuration will be read from that file, overriding any later -// runtime configuration. -func ConfigureFromFile(configLocation string) (ctx *pkcs11.Ctx, err error) { - file, err := os.Open(configLocation) +// ConfigureFromFile is a convenience method, which parses the configuration file +// and calls Configure. The configuration file should be a JSON representation +// of a Config object. +func ConfigureFromFile(configLocation string) (*Context, error) { + config, err := loadConfigFromFile(configLocation) if err != nil { - log.Printf("Could not open config file: %s", configLocation) return nil, err } + + return Configure(config) +} + +// loadConfigFromFile reads a Config struct from a file. +func loadConfigFromFile(configLocation string) (*Config, error) { + file, err := os.Open(configLocation) + if err != nil { + return nil, errors.WithMessagef(err, "could not open config file: %s", configLocation) + } defer func() { closeErr := file.Close() if err == nil { @@ -299,50 +365,28 @@ func ConfigureFromFile(configLocation string) (ctx *pkcs11.Ctx, err error) { }() configDecoder := json.NewDecoder(file) - config := &PKCS11Config{} + config := &Config{} err = configDecoder.Decode(config) - if err != nil { - log.Printf("Could decode config file: %s", err.Error()) - return nil, err - } - return Configure(config) + return config, errors.WithMessage(err, "could decode config file:") } -// Close releases all sessions and uninitializes library default handle. -// Once library handle is released, library may be configured once again. -func Close() error { - ctx := instance.ctx - if ctx != nil { - slots, err := ctx.GetSlotList(true) - if err != nil { - return err - } +// Close releases all the resources used by the Context and unloads the PKCS #11 library. Close blocks until existing +// operations have finished. A closed Context cannot be reused. +func (c *Context) Close() error { + c.closed.Set(true) - for _, slot := range slots { - if err := pool.closeSessions(slot); err != nil && err != errPoolNotFound { - return err - } - // if something by passed cache - if err := ctx.CloseAllSessions(slot); err != nil { - return err - } - } + // Block until all resources returned to pool + c.pool.Close() - if err := ctx.Finalize(); err != nil { - return err - } + // Close our long-term session. We ignore any returned error, + // since we plan to kill our collection to the library anyway. + _ = c.ctx.CloseSession(c.persistentSession) - ctx.Destroy() - instance.ctx = nil + err := c.ctx.Finalize() + if err != nil { + return err } + c.ctx.Destroy() return nil } - -func init() { - if configLocation, ok := os.LookupEnv("CRYPTO11_CONFIG_PATH"); ok { - if _, err := ConfigureFromFile(configLocation); err != nil { - panic(err) - } - } -} diff --git a/crypto11_test.go b/crypto11_test.go index 1a56ec2..fc84ff5 100644 --- a/crypto11_test.go +++ b/crypto11_test.go @@ -22,197 +22,55 @@ package crypto11 import ( - "crypto" "crypto/dsa" "encoding/json" "fmt" - "github.com/miekg/pkcs11" - "github.com/stretchr/testify/require" "log" + "math/rand" "os" "testing" "time" -) - -func TestInitializeFromConfig(t *testing.T) { - var config PKCS11Config - config.Path = "NoSuchFile" - config.Pin = "NoSuchPin" - config.TokenSerial = "NoSuchToken" - config.TokenLabel = "NoSuchToken" - //assert.Panics(Configure(config), "Invalid config should panic") - _, err := ConfigureFromFile("config") - require.NoError(t, err) - require.NoError(t, Close()) -} -func TestLoginContext(t *testing.T) { - t.Run("key identity with login", func(t *testing.T) { - _, err := configureWithPin(t) - require.NoError(t, err) + "github.com/stretchr/testify/assert" - defer func() { - err = Close() - require.NoError(t, err) - }() - - // Generate a key and and close a session - var key *PKCS11PrivateKeyDSA - psize := dsa.L1024N160 - if key, err = GenerateDSAKeyPair(dsaSizes[psize]); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } - - var id []byte - if id, _, err = key.Identify(); err != nil { - t.Errorf("crypto11.dsa.PKCS11PrivateKeyDSA.Identify: %v", err) - return - } - if err = Close(); err != nil { - t.Fatal(err) - } - - // Reopen a session and try to find a key. - // Valid session must enlist a key. - // If login is not performed than it will fail. - _, err = configureWithPin(t) - require.NoError(t, err) - - var key2 crypto.PrivateKey - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.dsa.FindDSAKeyPair by id: %v", err) - return - } - testDsaSigning(t, key2.(*PKCS11PrivateKeyDSA), psize, fmt.Sprintf("close%d", 0)) - }) - - t.Run("key identity with expiration", func(t *testing.T) { - prevIdleTimeout := instance.cfg.IdleTimeout - defer func() {instance.cfg.IdleTimeout = prevIdleTimeout}() - instance.cfg.IdleTimeout = time.Second - - _, err := configureWithPin(t) - require.NoError(t, err) - - defer func() { - err = Close() - require.NoError(t, err) - }() - - // Generate a key and and close a session - var key *PKCS11PrivateKeyDSA - psize := dsa.L1024N160 - if key, err = GenerateDSAKeyPair(dsaSizes[psize]); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } - - var id []byte - if id, _, err = key.Identify(); err != nil { - t.Errorf("crypto11.dsa.PKCS11PrivateKeyDSA.Identify: %v", err) - return - } - - // kick out all cfg.Idle sessions - time.Sleep(instance.cfg.IdleTimeout + time.Second) - - var key2 crypto.PrivateKey - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.dsa.FindDSAKeyPair by id: %v", err) - return - } - testDsaSigning(t, key2.(*PKCS11PrivateKeyDSA), psize, fmt.Sprintf("close%d", 0)) - }) - - t.Run("login context shared between sessions", func(t *testing.T) { - _, err := configureWithPin(t) - require.NoError(t, err) - - defer func() { - err = Close() - require.NoError(t, err) - }() - - // Generate a key and and close a session - var key *PKCS11PrivateKeyDSA - psize := dsa.L1024N160 - if key, err = GenerateDSAKeyPair(dsaSizes[psize]); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } - - var id []byte - if id, _, err = key.Identify(); err != nil { - t.Errorf("crypto11.dsa.PKCS11PrivateKeyDSA.Identify: %v", err) - return - } - - if err = withSession(instance.slot, func(s1 *PKCS11Session) error { - return withSession(instance.slot, func(s2 *PKCS11Session) error { - var key2 crypto.PrivateKey - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.dsa.FindDSAKeyPair by id: %v", err) - return nil - } - testDsaSigning(t, key2.(*PKCS11PrivateKeyDSA), psize, fmt.Sprintf("close%d", 0)) - return nil - }) - }); err != nil { - t.Errorf("with session failed: %s", err.Error()) - return - } - }) -} - -func TestIdentityExpiration(t *testing.T) { - prevIdleTimeout := instance.cfg.IdleTimeout - defer func() {instance.cfg.IdleTimeout = prevIdleTimeout}() - instance.cfg.IdleTimeout = time.Second + "github.com/stretchr/testify/require" +) - _, err := configureWithPin(t) +func TestKeysPersistAcrossContexts(t *testing.T) { + ctx, err := configureWithPin(t) require.NoError(t, err) defer func() { - err = Close() + err = ctx.Close() require.NoError(t, err) }() // Generate a key and and close a session - var key *PKCS11PrivateKeyDSA - psize := dsa.L1024N160 - if key, err = GenerateDSAKeyPair(dsaSizes[psize]); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } + const pSize = dsa.L1024N160 + id := randomBytes() + key, err := ctx.GenerateDSAKeyPair(id, dsaSizes[pSize]) + require.NoError(t, err) + require.NotNil(t, key) - // kick out all cfg.Idle sessions - time.Sleep(instance.cfg.IdleTimeout + time.Second) + err = ctx.Close() + require.NoError(t, err) - if _, _, err = key.Identify(); err != nil { - if perr, ok := err.(pkcs11.Error); !ok || perr != pkcs11.CKR_OBJECT_HANDLE_INVALID { - t.Fatal("failed to generate a key, unexpected error:", err) - } - } + // Reopen a session and try to find a key. + // Valid session must enlist a key. + // If login is not performed than it will fail. + ctx, err = configureWithPin(t) + require.NoError(t, err) + + key2, err := ctx.FindKeyPair(id, nil) + require.NoError(t, err) + + testDsaSigning(t, key2.(*pkcs11PrivateKeyDSA), pSize, fmt.Sprintf("close%d", 0)) + + err = key2.Delete() + require.NoError(t, err) } -func configureWithPin(t *testing.T) (*pkcs11.Ctx, error) { +func configureWithPin(t *testing.T) (*Context, error) { cfg, err := getConfig("config") if err != nil { t.Fatal(err) @@ -230,7 +88,7 @@ func configureWithPin(t *testing.T) (*pkcs11.Ctx, error) { return ctx, nil } -func getConfig(configLocation string) (ctx *PKCS11Config, err error) { +func getConfig(configLocation string) (ctx *Config, err error) { file, err := os.Open(configLocation) if err != nil { log.Printf("Could not open config file: %s", configLocation) @@ -241,7 +99,7 @@ func getConfig(configLocation string) (ctx *PKCS11Config, err error) { }() configDecoder := json.NewDecoder(file) - config := &PKCS11Config{} + config := &Config{} err = configDecoder.Decode(config) if err != nil { log.Printf("Could decode config file: %s", err.Error()) @@ -250,3 +108,109 @@ func getConfig(configLocation string) (ctx *PKCS11Config, err error) { return config, nil } +func TestKeyPairDelete(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + id := randomBytes() + key, err := ctx.GenerateRSAKeyPair(id, 2048) + require.NoError(t, err) + + // Check we can find it + _, err = ctx.FindKeyPair(id, nil) + require.NoError(t, err) + + err = key.Delete() + require.NoError(t, err) + + k, err := ctx.FindKeyPair(id, nil) + require.NoError(t, err) + require.Nil(t, k) +} + +func TestKeyDelete(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + id := randomBytes() + key, err := ctx.GenerateSecretKey(id, 128, CipherAES) + require.NoError(t, err) + + // Check we can find it + _, err = ctx.FindKey(id, nil) + require.NoError(t, err) + + err = key.Delete() + require.NoError(t, err) + + k, err := ctx.FindKey(id, nil) + require.NoError(t, err) + require.Nil(t, k) +} + +func TestAmbiguousTokenConfig(t *testing.T) { + slotNum := 1 + badConfigs := []*Config{ + {TokenSerial: "serial", TokenLabel: "label"}, + {TokenSerial: "serial", SlotNumber: &slotNum}, + {SlotNumber: &slotNum, TokenLabel: "label"}, + } + + for _, config := range badConfigs { + _, err := Configure(config) + assert.Equal(t, errAmbiguousToken, err) + } +} + +func TestSelectBySlot(t *testing.T) { + config, err := loadConfigFromFile("config") + require.NoError(t, err) + + // Look up slot number for label + ctx, err := Configure(config) + require.NoError(t, err) + + slotNumber := int(ctx.slot) + t.Logf("Using slot %d", slotNumber) + err = ctx.Close() + require.NoError(t, err) + + slotConfig := &Config{ + SlotNumber: &slotNumber, + Pin: config.Pin, + Path: config.Path, + } + + ctx, err = Configure(slotConfig) + require.NoError(t, err) + + slotNumber2 := int(ctx.slot) + err = ctx.Close() + require.NoError(t, err) + + assert.Equal(t, slotNumber, slotNumber2) +} + +func TestSelectByNonExistingSlot(t *testing.T) { + config, err := loadConfigFromFile("config") + require.NoError(t, err) + + rand.Seed(time.Now().UnixNano()) + randomSlot := int(rand.Uint32()) + + config.TokenLabel = "" + config.TokenSerial = "" + config.SlotNumber = &randomSlot + + // Look up slot number for label + _, err = Configure(config) + require.Equal(t, errTokenNotFound, err) +} diff --git a/demo/README.md b/demo/README.md deleted file mode 100644 index ba2713e..0000000 --- a/demo/README.md +++ /dev/null @@ -1,14 +0,0 @@ -Demo Program -============ - -A demo program using a PKCS#11-protected key to authenticate a web server. - -To use with nShield PKCS#11, assuming an OCS-protected key: - - generatekey -b pkcs11req protect=token type=rsa size=2048 plainname=demo \ - selfcert=yes embedsavefile=hkey.pem digest=sha256 \ - x509country=GB x509province=England x509locality=Rutland x509org=org x509orgunit=any \ - x509dnscommon=www.example.com - CKNFAST_DEBUG=2 CRYPTO11_CONFIG_PATH=../configs/config.nshield go run server.go - -`plainname` corresponds to CKA_LABEL. diff --git a/demo/server.go b/demo/server.go deleted file mode 100644 index 80a7349..0000000 --- a/demo/server.go +++ /dev/null @@ -1,93 +0,0 @@ -// A demo program using a PKCS#11-protected key to authenticate a web server. -// -// To use with nShield PKCS#11, assuming an OCS-protected key: -// -// generatekey -b pkcs11req protect=token type=rsa size=2048 plainname=demo \ -// selfcert=yes embedsavefile=hkey.pem digest=sha256 \ -// x509country=GB x509province=England x509locality=Rutland x509org=org x509orgunit=any \ -// x509dnscommon=www.example.com -// CKNFAST_DEBUG=2 CRYPTO11_CONFIG_PATH=../configs/config.nshield go run server.go -// -// 'plainname' corresponds to CKA_LABEL. -package main - -import ( - "crypto" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "fmt" - "github.com/ThalesIgnite/crypto11" - "io/ioutil" - "log" - "net/http" -) - -var keyLabel = "demo" -var certFile = "hkey_selfcert.pem" - -func useHardwareKey(config *tls.Config, keyLabel string, certFile string) error { - var err error - var cert tls.Certificate - var certPEM []byte - var certDER *pem.Block - var certParsed *x509.Certificate - var key crypto.PrivateKey - - // Load the certificate. What we really want was for - // crypto.tls.X509KeyPair or crypto.tls.LoadX509KeyPair to - // accept keyFile=nil, but they don't, so we have to load the - // certificate manually. - log.Printf("loading certificate %s", certFile) - if certPEM, err = ioutil.ReadFile(certFile); err != nil { - return err - } - for { - if certDER, certPEM = pem.Decode(certPEM); certDER == nil { - break - } - if certDER.Type != "CERTIFICATE" { - return fmt.Errorf("%s: unexpected type %s", certFile, certDER.Type) - } - cert.Certificate = append(cert.Certificate, certDER.Bytes) - } - if len(cert.Certificate) == 0 { - return fmt.Errorf("%s: no certificates found", certFile) - } - // Load the private key. - log.Printf("loading key CKA_LABEL=%s", keyLabel) - if key, err = crypto11.FindKeyPair(nil, []byte(keyLabel)); err != nil { - return err - } - cert.PrivateKey = key - // Check that the key and the certificate match. - if certParsed, err = x509.ParseCertificate(cert.Certificate[0]); err != nil { - return err - } - switch certPubKey := certParsed.PublicKey.(type) { - case *rsa.PublicKey: - if keyPubKey, ok := key.(crypto.Signer).Public().(*rsa.PublicKey); ok { - if certPubKey.E != keyPubKey.E || certPubKey.N.Cmp(keyPubKey.N) != 0 { - return fmt.Errorf("%s: public key does not match CKA_LABEL=%s", certFile, keyLabel) - } - } else { - return fmt.Errorf("%s: key type does not match CKA_LABEL=%s", certFile, keyLabel) - } - default: - return fmt.Errorf("%s: key type not implemented", certFile) - } - // It's all good, update the configuration - config.Certificates = []tls.Certificate{cert} - return nil -} - -func main() { - http.Handle("/", http.FileServer(http.Dir("/usr/share/doc"))) - server := &http.Server{Addr: ":9090", TLSConfig: &tls.Config{}} - if err := useHardwareKey(server.TLSConfig, keyLabel, certFile); err != nil { - log.Fatal(err) - } - log.Printf("starting server on %s", server.Addr) - log.Fatal(server.ListenAndServeTLS("", "")) -} diff --git a/dsa.go b/dsa.go index a9d0b09..8ab3194 100644 --- a/dsa.go +++ b/dsa.go @@ -27,23 +27,25 @@ import ( "io" "math/big" + "github.com/pkg/errors" + pkcs11 "github.com/miekg/pkcs11" ) -// PKCS11PrivateKeyDSA contains a reference to a loaded PKCS#11 DSA private key object. -type PKCS11PrivateKeyDSA struct { - PKCS11PrivateKey +// pkcs11PrivateKeyDSA contains a reference to a loaded PKCS#11 DSA private key object. +type pkcs11PrivateKeyDSA struct { + pkcs11PrivateKey } // Export the public key corresponding to a private DSA key. -func exportDSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { +func exportDSAPublicKey(session *pkcs11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { template := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_PRIME, nil), pkcs11.NewAttribute(pkcs11.CKA_SUBPRIME, nil), pkcs11.NewAttribute(pkcs11.CKA_BASE, nil), pkcs11.NewAttribute(pkcs11.CKA_VALUE, nil), } - exported, err := session.Ctx.GetAttributeValue(session.Handle, pubHandle, template) + exported, err := session.ctx.GetAttributeValue(session.handle, pubHandle, template) if err != nil { return nil, err } @@ -63,90 +65,112 @@ func exportDSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) ( return &result, nil } -// GenerateDSAKeyPair creates a DSA private key on the default slot -// -// The key will have a random label and ID. -func GenerateDSAKeyPair(params *dsa.Parameters) (*PKCS11PrivateKeyDSA, error) { - return GenerateDSAKeyPairOnSlot(instance.slot, nil, nil, params) +func notNilBytes(obj []byte, name string) error { + if obj == nil { + return errors.Errorf("%s cannot be nil", name) + } + return nil } -// GenerateDSAKeyPairOnSlot creates a DSA private key on a specified slot -// -// Either or both label and/or id can be nil, in which case random values will be generated. -func GenerateDSAKeyPairOnSlot(slot uint, id []byte, label []byte, params *dsa.Parameters) (*PKCS11PrivateKeyDSA, error) { - var k *PKCS11PrivateKeyDSA - var err error - if err = ensureSessions(instance, slot); err != nil { +// GenerateDSAKeyPair creates a DSA key pair on the token. The id parameter is used to +// set CKA_ID and must be non-nil. +func (c *Context) GenerateDSAKeyPair(id []byte, params *dsa.Parameters) (Signer, error) { + if c.closed.Get() { + return nil, errClosed + } + + if err := notNilBytes(id, "id"); err != nil { return nil, err } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = GenerateDSAKeyPairOnSession(session, slot, id, label, params) - return err - }) - return k, err + + return c.generateDSAKeyPair(id, nil, params) } -// GenerateDSAKeyPairOnSession creates a DSA private key using a specified session -// -// Either or both label and/or id can be nil, in which case random values will be generated. -func GenerateDSAKeyPairOnSession(session *PKCS11Session, slot uint, id []byte, label []byte, params *dsa.Parameters) (*PKCS11PrivateKeyDSA, error) { - var err error - var pub crypto.PublicKey - - if label == nil { - if label, err = generateKeyLabel(); err != nil { - return nil, err - } - } - if id == nil { - if id, err = generateKeyLabel(); err != nil { - return nil, err - } +// GenerateDSAKeyPairWithLabel creates a DSA key pair on the token. The id and label parameters are used to +// set CKA_ID and CKA_LABEL respectively and must be non-nil. +func (c *Context) GenerateDSAKeyPairWithLabel(id, label []byte, params *dsa.Parameters) (Signer, error) { + if c.closed.Get() { + return nil, errClosed } - p := params.P.Bytes() - q := params.Q.Bytes() - g := params.G.Bytes() - publicKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_DSA), - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), - pkcs11.NewAttribute(pkcs11.CKA_PRIME, p), - pkcs11.NewAttribute(pkcs11.CKA_SUBPRIME, q), - pkcs11.NewAttribute(pkcs11.CKA_BASE, g), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - } - privateKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), - pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - } - mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_DSA_KEY_PAIR_GEN, nil)} - pubHandle, privHandle, err := session.Ctx.GenerateKeyPair(session.Handle, - mech, - publicKeyTemplate, - privateKeyTemplate) - if err != nil { + + if err := notNilBytes(id, "id"); err != nil { return nil, err } - if pub, err = exportDSAPublicKey(session, pubHandle); err != nil { + if err := notNilBytes(label, "label"); err != nil { return nil, err } - priv := PKCS11PrivateKeyDSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}} - return &priv, nil + + return c.generateDSAKeyPair(id, label, params) +} + +// generateDSAKeyPair creates a DSA private key on the token. +func (c *Context) generateDSAKeyPair(id, label []byte, params *dsa.Parameters) (k *pkcs11PrivateKeyDSA, err error) { + err = c.withSession(func(session *pkcs11Session) error { + p := params.P.Bytes() + q := params.Q.Bytes() + g := params.G.Bytes() + + publicKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_DSA), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + pkcs11.NewAttribute(pkcs11.CKA_PRIME, p), + pkcs11.NewAttribute(pkcs11.CKA_SUBPRIME, q), + pkcs11.NewAttribute(pkcs11.CKA_BASE, g), + } + + privateKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + } + + if id != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + + if label != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_DSA_KEY_PAIR_GEN, nil)} + pubHandle, privHandle, err := session.ctx.GenerateKeyPair(session.handle, + mech, + publicKeyTemplate, + privateKeyTemplate) + if err != nil { + return err + } + pub, err := exportDSAPublicKey(session, pubHandle) + if err != nil { + return err + } + k = &pkcs11PrivateKeyDSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: privHandle, + context: c, + }, + pubKeyHandle: pubHandle, + pubKey: pub, + }} + return nil + + }) + return } // Sign signs a message using a DSA key. // -// This completes the implemention of crypto.Signer for PKCS11PrivateKeyDSA. +// This completes the implemention of crypto.Signer for pkcs11PrivateKeyDSA. // // PKCS#11 expects to pick its own random data for signatures, so the rand argument is ignored. // // The return value is a DER-encoded byteblock. -func (signer *PKCS11PrivateKeyDSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { - return dsaGeneric(signer.Slot, signer.Handle, pkcs11.CKM_DSA, digest) +func (signer *pkcs11PrivateKeyDSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + return signer.context.dsaGeneric(signer.handle, pkcs11.CKM_DSA, digest) } diff --git a/dsa_test.go b/dsa_test.go index 23764d9..f53937f 100644 --- a/dsa_test.go +++ b/dsa_test.go @@ -28,10 +28,11 @@ import ( _ "crypto/sha1" _ "crypto/sha256" _ "crypto/sha512" - "github.com/stretchr/testify/require" "io" "math/big" "testing" + + "github.com/stretchr/testify/require" ) // Use pre-cooked groups, making new ones is too slow and doesn't test @@ -100,39 +101,32 @@ func TestNativeDSA(t *testing.T) { } func TestHardDSA(t *testing.T) { - var key *PKCS11PrivateKeyDSA - var key2, key3 crypto.PrivateKey - var id, label []byte - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") require.NoError(t, err) - for psize, params := range dsaSizes { - if key, err = GenerateDSAKeyPair(params); err != nil { - t.Errorf("crypto11.GenerateDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateDSAKeyPair: returned nil but no error") - return - } - testDsaSigning(t, key, psize, "hard1") - // Get a fresh handle to the key - if id, label, err = key.Identify(); err != nil { - t.Errorf("crypto11.dsa.PKCS11PrivateKeyDSA.Identify: %v", err) - return - } - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.dsa.FindDSAKeyPair by id: %v", err) - return - } - testDsaSigning(t, key2.(*PKCS11PrivateKeyDSA), psize, "hard2") - if key3, err = FindKeyPair(nil, label); err != nil { - t.Errorf("crypto11.dsa.FindKeyPair by label: %v", err) - return - } - testDsaSigning(t, key3.(crypto.Signer), psize, "hard3") + defer func() { + err = ctx.Close() + require.NoError(t, err) + }() + + for pSize, params := range dsaSizes { + + id := randomBytes() + label := randomBytes() + + key, err := ctx.GenerateDSAKeyPairWithLabel(id, label, params) + require.NoError(t, err) + require.NotNil(t, key) + testDsaSigning(t, key, pSize, "hard1") + + key2, err := ctx.FindKeyPair(id, nil) + require.NoError(t, err) + testDsaSigning(t, key2.(*pkcs11PrivateKeyDSA), pSize, "hard2") + + key3, err := ctx.FindKeyPair(nil, label) + require.NoError(t, err) + testDsaSigning(t, key3.(crypto.Signer), pSize, "hard3") } - require.NoError(t, Close()) } func testDsaSigning(t *testing.T, key crypto.Signer, psize dsa.ParameterSizes, what string) { @@ -144,27 +138,46 @@ func testDsaSigning(t *testing.T, key crypto.Signer, psize dsa.ParameterSizes, w } func testDsaSigningWithHash(t *testing.T, key crypto.Signer, hashFunction crypto.Hash, psize dsa.ParameterSizes, what string) { - var err error - var sigDER []byte - var sig dsaSignature plaintext := []byte("sign me with DSA") h := hashFunction.New() - h.Write(plaintext) + _, err := h.Write(plaintext) + require.NoError(t, err) + plaintextHash := h.Sum([]byte{}) // weird API // crypto.dsa.Sign doesn't truncate the hash! qbytes := (dsaSizes[psize].Q.BitLen() + 7) / 8 plaintextHash = plaintextHash[:qbytes] - if sigDER, err = key.Sign(rand.Reader, plaintextHash, hashFunction); err != nil { - t.Errorf("DSA %s Sign (hash %v): %v", what, hashFunction, err) - return - } - if err = sig.unmarshalDER(sigDER); err != nil { - t.Errorf("DSA %s unmarshalDER (hash %v): %v", what, hashFunction, err) - return - } + + sigDER, err := key.Sign(rand.Reader, plaintextHash, hashFunction) + require.NoError(t, err) + + var sig dsaSignature + err = sig.unmarshalDER(sigDER) + require.NoError(t, err) + dsaPubkey := key.Public().(crypto.PublicKey).(*dsa.PublicKey) if !dsa.Verify(dsaPubkey, plaintextHash, sig.R, sig.S) { t.Errorf("DSA %s Verify failed (psize %d hash %v)", what, psize, hashFunction) } } + +func TestDsaRequiredArgs(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + _, err = ctx.GenerateDSAKeyPair(nil, dsaSizes[dsa.L2048N224]) + require.Error(t, err) + + val := randomBytes() + + _, err = ctx.GenerateDSAKeyPairWithLabel(nil, val, dsaSizes[dsa.L2048N224]) + require.Error(t, err) + + _, err = ctx.GenerateDSAKeyPairWithLabel(val, nil, dsaSizes[dsa.L2048N224]) + require.Error(t, err) +} diff --git a/ecdsa.go b/ecdsa.go index 4f406ca..975bc4a 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -27,26 +27,22 @@ import ( "crypto/ecdsa" "crypto/elliptic" "encoding/asn1" - "errors" "io" "math/big" - pkcs11 "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11" + "github.com/pkg/errors" ) -// ErrUnsupportedEllipticCurve is returned when an elliptic curve +// errUnsupportedEllipticCurve is returned when an elliptic curve // unsupported by crypto11 is specified. Note that the error behavior // for an elliptic curve unsupported by the underlying PKCS#11 // implementation will be different. -var ErrUnsupportedEllipticCurve = errors.New("crypto11/ecdsa: unsupported elliptic curve") +var errUnsupportedEllipticCurve = errors.New("unsupported elliptic curve") -// ErrMalformedPoint is returned when crypto.elliptic.Unmarshal cannot -// decode a point. -var ErrMalformedPoint = errors.New("crypto11/ecdsa: malformed elliptic curve point") - -// PKCS11PrivateKeyECDSA contains a reference to a loaded PKCS#11 ECDSA private key object. -type PKCS11PrivateKeyECDSA struct { - PKCS11PrivateKey +// pkcs11PrivateKeyECDSA contains a reference to a loaded PKCS#11 ECDSA private key object. +type pkcs11PrivateKeyECDSA struct { + pkcs11PrivateKey } // Information about an Elliptic Curve @@ -145,57 +141,44 @@ func marshalEcParams(c elliptic.Curve) ([]byte, error) { return ci.oid, nil } // TODO use ANSI X9.62 ECParameters representation instead - return nil, ErrUnsupportedEllipticCurve + return nil, errUnsupportedEllipticCurve } func unmarshalEcParams(b []byte) (elliptic.Curve, error) { // See if it's a well-known curve for _, ci := range wellKnownCurves { - if bytes.Compare(b, ci.oid) == 0 { + if bytes.Equal(b, ci.oid) { if ci.curve != nil { return ci.curve, nil } - return nil, ErrUnsupportedEllipticCurve + return nil, errUnsupportedEllipticCurve } } // TODO try ANSI X9.62 ECParameters representation - return nil, ErrUnsupportedEllipticCurve + return nil, errUnsupportedEllipticCurve } -func unmarshalEcPoint(b []byte, c elliptic.Curve) (x *big.Int, y *big.Int, err error) { - // Decoding an octet string in isolation seems to be too hard - // with encoding.asn1, so we do it manually. Look away now. - if b[0] != 4 { - return nil, nil, ErrMalformedDER - } - var l, r int - if b[1] < 128 { - l = int(b[1]) - r = 2 - } else { - ll := int(b[1] & 127) - if ll > 2 { // unreasonably long - return nil, nil, ErrMalformedDER - } - l = 0 - for i := int(0); i < ll; i++ { - l = 256*l + int(b[2+i]) - } - r = ll + 2 +func unmarshalEcPoint(b []byte, c elliptic.Curve) (*big.Int, *big.Int, error) { + var pointBytes []byte + extra, err := asn1.Unmarshal(b, &pointBytes) + if err != nil { + return nil, nil, errors.WithMessage(err, "elliptic curve point is invalid ASN.1") } - if r+l > len(b) { - return nil, nil, ErrMalformedDER + + if len(extra) > 0 { + // We weren't expecting extra data + return nil, nil, errors.New("unexpected data found when parsing elliptic curve point") } - pointBytes := b[r:] - x, y = elliptic.Unmarshal(c, pointBytes) + + x, y := elliptic.Unmarshal(c, pointBytes) if x == nil || y == nil { - err = ErrMalformedPoint + return nil, nil, errors.New("failed to parse elliptic curve point") } - return + return x, y, nil } // Export the public key corresponding to a private ECDSA key. -func exportECDSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { +func exportECDSAPublicKey(session *pkcs11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { var err error var attributes []*pkcs11.Attribute var pub ecdsa.PublicKey @@ -203,7 +186,7 @@ func exportECDSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) pkcs11.NewAttribute(pkcs11.CKA_ECDSA_PARAMS, nil), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil), } - if attributes, err = session.Ctx.GetAttributeValue(session.Handle, pubHandle, template); err != nil { + if attributes, err = session.ctx.GetAttributeValue(session.handle, pubHandle, template); err != nil { return nil, err } if pub.Curve, err = unmarshalEcParams(attributes[0].Value); err != nil { @@ -215,98 +198,105 @@ func exportECDSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) return &pub, nil } -// GenerateECDSAKeyPair creates an ECDSA private key using curve c. -// -// The key will have a random label and ID. -// -// Only a limited set of named elliptic curves are supported. The +// GenerateECDSAKeyPair creates a ECDSA key pair on the token using curve c. The id parameter is used to +// set CKA_ID and must be non-nil. Only a limited set of named elliptic curves are supported. The // underlying PKCS#11 implementation may impose further restrictions. -func GenerateECDSAKeyPair(c elliptic.Curve) (*PKCS11PrivateKeyECDSA, error) { - return GenerateECDSAKeyPairOnSlot(instance.slot, nil, nil, c) -} +func (c *Context) GenerateECDSAKeyPair(id []byte, curve elliptic.Curve) (Signer, error) { + if c.closed.Get() { + return nil, errClosed + } -// GenerateECDSAKeyPairOnSlot creates an ECDSA private key using curve c, on a specified slot. -// -// label and/or id can be nil, in which case random values will be generated. -// -// Only a limited set of named elliptic curves are supported. The -// underlying PKCS#11 implementation may impose further restrictions. -func GenerateECDSAKeyPairOnSlot(slot uint, id []byte, label []byte, c elliptic.Curve) (*PKCS11PrivateKeyECDSA, error) { - var k *PKCS11PrivateKeyECDSA - var err error - if err = ensureSessions(instance, slot); err != nil { + if err := notNilBytes(id, "id"); err != nil { return nil, err } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = GenerateECDSAKeyPairOnSession(session, slot, id, label, c) - return err - }) - return k, err + + return c.generateECDSAKeyPair(id, nil, curve) } -// GenerateECDSAKeyPairOnSession creates an ECDSA private key using curve c, using a specified session. -// -// label and/or id can be nil, in which case random values will be generated. -// -// Only a limited set of named elliptic curves are supported. The +// GenerateECDSAKeyPairWithLabel creates a ECDSA key pair on the token using curve c. The id and label parameters are used to +// set CKA_ID and CKA_LABEL respectively and must be non-nil. Only a limited set of named elliptic curves are supported. The // underlying PKCS#11 implementation may impose further restrictions. -func GenerateECDSAKeyPairOnSession(session *PKCS11Session, slot uint, id []byte, label []byte, c elliptic.Curve) (*PKCS11PrivateKeyECDSA, error) { - var err error - var parameters []byte - var pub crypto.PublicKey - - if label == nil { - if label, err = generateKeyLabel(); err != nil { - return nil, err - } +func (c *Context) GenerateECDSAKeyPairWithLabel(id, label []byte, curve elliptic.Curve) (Signer, error) { + if c.closed.Get() { + return nil, errClosed } - if id == nil { - if id, err = generateKeyLabel(); err != nil { - return nil, err - } - } - if parameters, err = marshalEcParams(c); err != nil { - return nil, err - } - publicKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_ECDSA), - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - pkcs11.NewAttribute(pkcs11.CKA_ECDSA_PARAMS, parameters), - } - privateKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), - pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - } - mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA_KEY_PAIR_GEN, nil)} - pubHandle, privHandle, err := session.Ctx.GenerateKeyPair(session.Handle, - mech, - publicKeyTemplate, - privateKeyTemplate) - if err != nil { + + if err := notNilBytes(id, "id"); err != nil { return nil, err } - if pub, err = exportECDSAPublicKey(session, pubHandle); err != nil { + if err := notNilBytes(label, "label"); err != nil { return nil, err } - priv := PKCS11PrivateKeyECDSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}} - return &priv, nil + + return c.generateECDSAKeyPair(id, label, curve) +} + +// generateECDSAKeyPair generates a key pair on the token. +func (c *Context) generateECDSAKeyPair(id, label []byte, curve elliptic.Curve) (k *pkcs11PrivateKeyECDSA, err error) { + err = c.withSession(func(session *pkcs11Session) error { + + parameters, err := marshalEcParams(curve) + if err != nil { + return err + } + publicKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_ECDSA), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + pkcs11.NewAttribute(pkcs11.CKA_ECDSA_PARAMS, parameters), + } + privateKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + } + + if id != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + + if label != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA_KEY_PAIR_GEN, nil)} + pubHandle, privHandle, err := session.ctx.GenerateKeyPair(session.handle, + mech, + publicKeyTemplate, + privateKeyTemplate) + if err != nil { + return err + } + + pub, err := exportECDSAPublicKey(session, pubHandle) + if err != nil { + return err + } + k = &pkcs11PrivateKeyECDSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: privHandle, + context: c, + }, + pubKeyHandle: pubHandle, + pubKey: pub, + }} + return nil + }) + return } // Sign signs a message using an ECDSA key. // -// This completes the implemention of crypto.Signer for PKCS11PrivateKeyECDSA. +// This completes the implemention of crypto.Signer for pkcs11PrivateKeyECDSA. // // PKCS#11 expects to pick its own random data where necessary for signatures, so the rand argument is ignored. // // The return value is a DER-encoded byteblock. -func (signer *PKCS11PrivateKeyECDSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return dsaGeneric(signer.Slot, signer.Handle, pkcs11.CKM_ECDSA, digest) +func (signer *pkcs11PrivateKeyECDSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + return signer.context.dsaGeneric(signer.handle, pkcs11.CKM_ECDSA, digest) } diff --git a/ecdsa_test.go b/ecdsa_test.go index 21d4e11..3422fac 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -29,8 +29,9 @@ import ( _ "crypto/sha1" _ "crypto/sha256" _ "crypto/sha512" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) var curves = []elliptic.Curve{ @@ -58,65 +59,76 @@ func TestNativeECDSA(t *testing.T) { } func TestHardECDSA(t *testing.T) { - var key *PKCS11PrivateKeyECDSA - var key2, key3 crypto.PrivateKey - var id, label []byte - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") require.NoError(t, err) + defer func() { + err = ctx.Close() + require.NoError(t, err) + }() + for _, curve := range curves { - if key, err = GenerateECDSAKeyPair(curve); err != nil { - t.Errorf("GenerateECDSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateECDSAKeyPair: returned nil but no error") - return - } + id := randomBytes() + label := randomBytes() + + key, err := ctx.GenerateECDSAKeyPairWithLabel(id, label, curve) + require.NoError(t, err) + require.NotNil(t, key) + testEcdsaSigning(t, key, crypto.SHA1) testEcdsaSigning(t, key, crypto.SHA224) testEcdsaSigning(t, key, crypto.SHA256) testEcdsaSigning(t, key, crypto.SHA384) testEcdsaSigning(t, key, crypto.SHA512) - // Get a fresh handle to the key - if id, label, err = key.Identify(); err != nil { - t.Errorf("crypto11.ecdsa.PKCS11PrivateKeyECDSA.Identify: %v", err) - return - } - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.ecdsa.FindECDSAKeyPair by id: %v", err) - return - } - testEcdsaSigning(t, key2.(*PKCS11PrivateKeyECDSA), crypto.SHA256) - if key3, err = FindKeyPair(nil, label); err != nil { - t.Errorf("crypto11.ecdsa.FindKeyPair by label: %v", err) - return - } + + key2, err := ctx.FindKeyPair(id, nil) + require.NoError(t, err) + testEcdsaSigning(t, key2.(*pkcs11PrivateKeyECDSA), crypto.SHA256) + + key3, err := ctx.FindKeyPair(nil, label) + require.NoError(t, err) testEcdsaSigning(t, key3.(crypto.Signer), crypto.SHA384) } - require.NoError(t, Close()) } func testEcdsaSigning(t *testing.T, key crypto.Signer, hashFunction crypto.Hash) { - var err error - var sigDER []byte - var sig dsaSignature plaintext := []byte("sign me with ECDSA") h := hashFunction.New() - h.Write(plaintext) + _, err := h.Write(plaintext) + require.NoError(t, err) plaintextHash := h.Sum([]byte{}) // weird API - if sigDER, err = key.Sign(rand.Reader, plaintextHash, nil); err != nil { - t.Errorf("ECDSA Sign (hash %v): %v", hashFunction, err) - return - } - if err = sig.unmarshalDER(sigDER); err != nil { - t.Errorf("ECDSA unmarshalDER (hash %v): %v", hashFunction, err) - return - } + + sigDER, err := key.Sign(rand.Reader, plaintextHash, nil) + require.NoError(t, err) + + var sig dsaSignature + err = sig.unmarshalDER(sigDER) + require.NoError(t, err) + ecdsaPubkey := key.Public().(crypto.PublicKey).(*ecdsa.PublicKey) if !ecdsa.Verify(ecdsaPubkey, plaintextHash, sig.R, sig.S) { t.Errorf("ECDSA Verify (hash %v): %v", hashFunction, err) } } + +func TestEcdsaRequiredArgs(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + _, err = ctx.GenerateECDSAKeyPair(nil, elliptic.P224()) + require.Error(t, err) + + val := randomBytes() + + _, err = ctx.GenerateECDSAKeyPairWithLabel(nil, val, elliptic.P224()) + require.Error(t, err) + + _, err = ctx.GenerateECDSAKeyPairWithLabel(val, nil, elliptic.P224()) + require.Error(t, err) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..45b2b80 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/ThalesIgnite/crypto11 + +go 1.12 + +require ( + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect + github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f + github.com/pkg/errors v0.8.1 + github.com/stretchr/testify v1.3.0 + github.com/vitessio/vitess v2.1.1+incompatible + github.com/youtube/vitess v2.1.1+incompatible // indirect + golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..73b1a93 --- /dev/null +++ b/go.sum @@ -0,0 +1,24 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/miekg/pkcs11 v1.0.2 h1:CIBkOawOtzJNE0B+EpRiUBzuVW7JEQAwdwhSS6YhIeg= +github.com/miekg/pkcs11 v1.0.2/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f h1:eVB9ELsoq5ouItQBr5Tj334bhPJG/MX+m7rTchmzVUQ= +github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/vitessio/vitess v2.1.1+incompatible h1:zE9Moh7xCrFDpniUsYPVoiOWtdTK0TWoHUjwFV7iFCA= +github.com/vitessio/vitess v2.1.1+incompatible/go.mod h1:A11WWLimUfZAYYm8P1I63RryRPP2GdpHRgQcfa++OnQ= +github.com/youtube/vitess v2.1.1+incompatible h1:SE+P7DNX/jw5RHFs5CHRhZQjq402EJFCD33JhzQMdDw= +github.com/youtube/vitess v2.1.1+incompatible/go.mod h1:hpMim5/30F1r+0P8GGtB29d0gWHr0IZ5unS+CG0zMx8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6 h1:FP8hkuE6yUEaJnK7O2eTuejKWwW+Rhfj80dQ2JcKxCU= +golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/hmac.go b/hmac.go index dc710ae..16b7415 100644 --- a/hmac.go +++ b/hmac.go @@ -22,12 +22,10 @@ package crypto11 import ( - "context" "errors" - "fmt" - "github.com/miekg/pkcs11" - "github.com/youtube/vitess/go/pools" "hash" + + "github.com/miekg/pkcs11" ) const ( @@ -58,10 +56,10 @@ const ( type hmacImplementation struct { // PKCS#11 session to use - session *PKCS11Session + session *pkcs11Session // Signing key - key *PKCS11SecretKey + key *SecretKey // Hash size size int @@ -109,8 +107,8 @@ var hmacInfos = map[int]*hmacInfo{ pkcs11.CKM_RIPEMD160_HMAC_GENERAL: {20, 64, true}, } -// ErrHmacClosed is called if an HMAC is updated after it has finished. -var ErrHmacClosed = errors.New("already called Sum()") +// errHmacClosed is called if an HMAC is updated after it has finished. +var errHmacClosed = errors.New("already called Sum()") // NewHMAC returns a new HMAC hash using the given PKCS#11 mechanism // and key. @@ -122,9 +120,8 @@ var ErrHmacClosed = errors.New("already called Sum()") // // The Reset() method is not implemented. // After Sum() is called no new data may be added. -func (key *PKCS11SecretKey) NewHMAC(mech int, length int) (h hash.Hash, err error) { - var hi hmacImplementation - hi = hmacImplementation{ +func (key *SecretKey) NewHMAC(mech int, length int) (hash.Hash, error) { + hi := hmacImplementation{ key: key, } var params []byte @@ -140,36 +137,24 @@ func (key *PKCS11SecretKey) NewHMAC(mech int, length int) (h hash.Hash, err erro hi.size = length } hi.mechDescription = []*pkcs11.Mechanism{pkcs11.NewMechanism(uint(mech), params)} - if err = hi.initialize(); err != nil { - return + if err := hi.initialize(); err != nil { + return nil, err } - h = &hi - return + return &hi, nil } func (hi *hmacImplementation) initialize() (err error) { - // TODO refactor with newBlockModeCloser - sessionPool := pool.Get(hi.key.Slot) - if sessionPool == nil { - err = fmt.Errorf("crypto11: no session for slot %d", hi.key.Slot) - return - } - ctx := context.Background() - if instance.cfg.PoolWaitTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), instance.cfg.PoolWaitTimeout) - defer cancel() + session, err := hi.key.context.getSession() + if err != nil { + return err } - var session pools.Resource - if session, err = sessionPool.Get(ctx); err != nil { - return - } - hi.session = session.(*PKCS11Session) + + hi.session = session hi.cleanup = func() { - sessionPool.Put(session) + hi.key.context.pool.Put(session) hi.session = nil } - if err = hi.session.Ctx.SignInit(hi.session.Handle, hi.mechDescription, hi.key.Handle); err != nil { + if err = hi.session.ctx.SignInit(hi.session.handle, hi.mechDescription, hi.key.handle); err != nil { hi.cleanup() return } @@ -181,11 +166,11 @@ func (hi *hmacImplementation) initialize() (err error) { func (hi *hmacImplementation) Write(p []byte) (n int, err error) { if hi.result != nil { if len(p) > 0 { - err = ErrHmacClosed + err = errHmacClosed } return } - if err = hi.session.Ctx.SignUpdate(hi.session.Handle, p); err != nil { + if err = hi.session.ctx.SignUpdate(hi.session.handle, p); err != nil { return } hi.updates++ @@ -199,11 +184,11 @@ func (hi *hmacImplementation) Sum(b []byte) []byte { if hi.updates == 0 { // http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/os/pkcs11-base-v2.40-os.html#_Toc322855304 // We must ensure that C_SignUpdate is called _at least once_. - if err = hi.session.Ctx.SignUpdate(hi.session.Handle, []byte{}); err != nil { + if err = hi.session.ctx.SignUpdate(hi.session.handle, []byte{}); err != nil { panic(err) } } - hi.result, err = hi.session.Ctx.SignFinal(hi.session.Handle) + hi.result, err = hi.session.ctx.SignFinal(hi.session.handle) hi.cleanup() if err != nil { panic(err) diff --git a/hmac_test.go b/hmac_test.go index 4757065..28354f2 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -22,165 +22,134 @@ package crypto11 import ( - "bytes" + "testing" + "github.com/miekg/pkcs11" "github.com/stretchr/testify/require" - "hash" - "testing" ) func TestHmac(t *testing.T) { - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + err = ctx.Close() + require.NoError(t, err) + }() + + info, err := ctx.ctx.GetInfo() require.NoError(t, err) - var info pkcs11.Info - if info, err = instance.ctx.GetInfo(); err != nil { - t.Errorf("GetInfo: %v", err) - return - } if info.ManufacturerID == "SoftHSM" { t.Skipf("HMAC not implemented on SoftHSM") } t.Run("HMACSHA1", func(t *testing.T) { - testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC, 0, 20, false) + testHmac(t, ctx, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC, 0, 20, false) }) t.Run("HMACSHA1General", func(t *testing.T) { - testHmac(t, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC_GENERAL, 10, 10, true) + testHmac(t, ctx, pkcs11.CKK_SHA_1_HMAC, pkcs11.CKM_SHA_1_HMAC_GENERAL, 10, 10, true) }) t.Run("HMACSHA256", func(t *testing.T) { - testHmac(t, pkcs11.CKK_SHA256_HMAC, pkcs11.CKM_SHA256_HMAC, 0, 32, false) + testHmac(t, ctx, pkcs11.CKK_SHA256_HMAC, pkcs11.CKM_SHA256_HMAC, 0, 32, false) }) - require.NoError(t, Close()) + } -func testHmac(t *testing.T, keytype int, mech int, length int, xlength int, full bool) { - var err error - var key *PKCS11SecretKey - t.Run("Generate", func(t *testing.T) { - if key, err = GenerateSecretKey(256, Ciphers[keytype]); err != nil { - t.Errorf("crypto11.GenerateSecretKey: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.GenerateSecretKey: returned nil but no error") - return - } - }) - if key == nil { - return - } +func testHmac(t *testing.T, ctx *Context, keytype int, mech int, length int, xlength int, full bool) { + + id := randomBytes() + key, err := ctx.GenerateSecretKey(id, 256, Ciphers[keytype]) + require.NoError(t, err) + require.NotNil(t, key) + t.Run("Short", func(t *testing.T) { input := []byte("a short string") - var h1, h2 hash.Hash - if h1, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } - if n, err := h1.Write(input); err != nil || n != len(input) { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + h1, err := key.NewHMAC(mech, length) + require.NoError(t, err) + + n, err := h1.Write(input) + require.NoError(t, err) + require.Equal(t, len(input), n) + r1 := h1.Sum([]byte{}) - if h2, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } - if n, err := h2.Write(input); err != nil || n != len(input) { - t.Errorf("h2.Write: %v/%d", err, n) - return - } + h2, err := key.NewHMAC(mech, length) + require.NoError(t, err) + + n, err = h2.Write(input) + require.NoError(t, err) + require.Equal(t, len(input), n) + r2 := h2.Sum([]byte{}) - if bytes.Compare(r1, r2) != 0 { - t.Errorf("h1/h2 inconsistent") - return - } - if len(r1) != xlength { - t.Errorf("r1 wrong length (want %v got %v)", xlength, len(r1)) - return - } + + require.Equal(t, r1, r2) + require.Len(t, r1, xlength) }) if full { // Independent of hash, only do these once t.Run("Empty", func(t *testing.T) { // Must be able to MAC empty inputs without panicing - var h1 hash.Hash - if h1, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } + h1, err := key.NewHMAC(mech, length) + require.NoError(t, err) h1.Sum([]byte{}) }) t.Run("MultiSum", func(t *testing.T) { input := []byte("a different short string") - var h1 hash.Hash - if h1, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } - if n, err := h1.Write(input); err != nil || n != len(input) { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + + h1, err := key.NewHMAC(mech, length) + require.NoError(t, err) + + n, err := h1.Write(input) + require.NoError(t, err) + require.Equal(t, len(input), n) + r1 := h1.Sum([]byte{}) r2 := h1.Sum([]byte{}) - if bytes.Compare(r1, r2) != 0 { - t.Errorf("r1/r2 inconsistent") - return - } + require.Equal(t, r1, r2) + // Can't add more after Sum() - if n, err := h1.Write(input); err != ErrHmacClosed { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + _, err = h1.Write(input) + require.Equal(t, errHmacClosed, err) + // 0-length is special - if n, err := h1.Write([]byte{}); err != nil || n != 0 { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + n, err = h1.Write([]byte{}) + require.NoError(t, err) + require.Zero(t, n) }) t.Run("Reset", func(t *testing.T) { - var h1 hash.Hash - if h1, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } - if n, err := h1.Write([]byte{1}); err != nil || n != 1 { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + + h1, err := key.NewHMAC(mech, length) + require.NoError(t, err) + + n, err := h1.Write([]byte{1}) + require.NoError(t, err) + require.Equal(t, 1, n) + r1 := h1.Sum([]byte{}) h1.Reset() - if n, err := h1.Write([]byte{2}); err != nil || n != 1 { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + + n, err = h1.Write([]byte{2}) + require.NoError(t, err) + require.Equal(t, 1, n) + r2 := h1.Sum([]byte{}) h1.Reset() - if n, err := h1.Write([]byte{1}); err != nil || n != 1 { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + + n, err = h1.Write([]byte{1}) + require.NoError(t, err) + require.Equal(t, 1, n) + r3 := h1.Sum([]byte{}) - if bytes.Compare(r1, r3) != 0 { - t.Errorf("r1/r3 inconsistent") - return - } - if bytes.Compare(r1, r2) == 0 { - t.Errorf("r1/r2 unexpectedly equal") - return - } + require.Equal(t, r1, r3) + require.NotEqual(t, r1, r2) }) t.Run("ResetFast", func(t *testing.T) { // Reset() immediately after creation should be safe - var h1 hash.Hash - if h1, err = key.NewHMAC(mech, length); err != nil { - t.Errorf("key.NewHMAC: %v", err) - return - } + + h1, err := key.NewHMAC(mech, length) + require.NoError(t, err) h1.Reset() - if n, err := h1.Write([]byte{2}); err != nil || n != 1 { - t.Errorf("h1.Write: %v/%d", err, n) - return - } + n, err := h1.Write([]byte{2}) + require.NoError(t, err) + require.Equal(t, 1, n) h1.Sum([]byte{}) }) } diff --git a/keys.go b/keys.go index f71113e..9320696 100644 --- a/keys.go +++ b/keys.go @@ -24,36 +24,25 @@ package crypto11 import ( "crypto" - pkcs11 "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11" + "github.com/pkg/errors" ) -// Identify returns the ID and label for a PKCS#11 object. -// -// Either of these values may be used to retrieve the key for later use. -func (object *PKCS11Object) Identify() (id []byte, label []byte, err error) { - a := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_ID, nil), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, nil), - } - if err = withSession(object.Slot, func(session *PKCS11Session) error { - a, err = instance.ctx.GetAttributeValue(session.Handle, object.Handle, a) - return err - }); err != nil { - return nil, nil, err - } - return a[0].Value, a[1].Value, nil -} - // Find a key object. For asymmetric keys this only finds one half so -// callers will call it twice. -func findKey(session *PKCS11Session, id []byte, label []byte, keyclass uint, keytype uint) (obj pkcs11.ObjectHandle, err error) { +// callers will call it twice. Returns nil if the key does not exist on the token. +func findKey(session *pkcs11Session, id []byte, label []byte, keyclass *uint, keytype *uint) (obj *pkcs11.ObjectHandle, err error) { var handles []pkcs11.ObjectHandle var template []*pkcs11.Attribute - if keyclass != ^uint(0) { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyclass)) + + if id == nil && label == nil { + return nil, errors.New("id and label cannot both be nil") + } + + if keyclass != nil { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_CLASS, *keyclass)) } - if keytype != ^uint(0) { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, keytype)) + if keytype != nil { + template = append(template, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, *keytype)) } if id != nil { template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) @@ -61,140 +50,165 @@ func findKey(session *PKCS11Session, id []byte, label []byte, keyclass uint, key if label != nil { template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) } - if err = session.Ctx.FindObjectsInit(session.Handle, template); err != nil { - return 0, err + if err = session.ctx.FindObjectsInit(session.handle, template); err != nil { + return nil, err } defer func() { - finalErr := session.Ctx.FindObjectsFinal(session.Handle) + finalErr := session.ctx.FindObjectsFinal(session.handle) if err == nil { err = finalErr } }() - if handles, _, err = session.Ctx.FindObjects(session.Handle, 1); err != nil { - return 0, err + if handles, _, err = session.ctx.FindObjects(session.handle, 1); err != nil { + return nil, err } if len(handles) == 0 { - return 0, ErrKeyNotFound + return nil, nil } - return handles[0], nil + return &handles[0], nil } -// FindKeyPair retrieves a previously created asymmetric key. +// FindKeyPair retrieves a previously created asymmetric key pair, or nil if it cannot be found. // -// Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKeyPair(id []byte, label []byte) (crypto.PrivateKey, error) { - return FindKeyPairOnSlot(instance.slot, id, label) -} +// At least one of id and label must be specified. If the private key is found, but the public key is +// not, an error is returned because we cannot implement crypto.Signer without the public key. +func (c *Context) FindKeyPair(id []byte, label []byte) (Signer, error) { -// FindKeyPairOnSlot retrieves a previously created asymmetric key, using a specified slot. -// -// Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKeyPairOnSlot(slot uint, id []byte, label []byte) (crypto.PrivateKey, error) { - var err error - var k crypto.PrivateKey - if err = ensureSessions(instance, slot); err != nil { - return nil, err + if c.closed.Get() { + return nil, errClosed } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = FindKeyPairOnSession(session, slot, id, label) - return err - }) - return k, err -} -// FindKeyPairOnSession retrieves a previously created asymmetric key, using a specified session. -// -// Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKeyPairOnSession(session *PKCS11Session, slot uint, id []byte, label []byte) (crypto.PrivateKey, error) { - var err error - var privHandle, pubHandle pkcs11.ObjectHandle - var pub crypto.PublicKey + var k Signer - if privHandle, err = findKey(session, id, label, pkcs11.CKO_PRIVATE_KEY, ^uint(0)); err != nil { - return nil, err - } - attributes := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, 0), - } - if attributes, err = session.Ctx.GetAttributeValue(session.Handle, privHandle, attributes); err != nil { - return nil, err - } - keyType := bytesToUlong(attributes[0].Value) - if pubHandle, err = findKey(session, id, label, pkcs11.CKO_PUBLIC_KEY, keyType); err != nil { - return nil, err - } - switch keyType { - case pkcs11.CKK_DSA: - if pub, err = exportDSAPublicKey(session, pubHandle); err != nil { - return nil, err + err := c.withSession(func(session *pkcs11Session) error { + + var pub crypto.PublicKey + + privHandle, err := findKey(session, id, label, uintPtr(pkcs11.CKO_PRIVATE_KEY), nil) + if err != nil { + return err + } + if privHandle == nil { + // Cannot continue, no key found + return nil } - return &PKCS11PrivateKeyDSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}}, nil - case pkcs11.CKK_RSA: - if pub, err = exportRSAPublicKey(session, pubHandle); err != nil { - return nil, err + + attributes := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, 0), } - return &PKCS11PrivateKeyRSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}}, nil - case pkcs11.CKK_ECDSA: - if pub, err = exportECDSAPublicKey(session, pubHandle); err != nil { - return nil, err + if attributes, err = session.ctx.GetAttributeValue(session.handle, *privHandle, attributes); err != nil { + return err } - return &PKCS11PrivateKeyECDSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}}, nil - default: - return nil, ErrUnsupportedKeyType - } + keyType := bytesToUlong(attributes[0].Value) + + pubHandle, err := findKey(session, id, label, uintPtr(pkcs11.CKO_PUBLIC_KEY), &keyType) + if err != nil { + return err + } + if pubHandle == nil { + // We can't return a Signer if we don't have private and public key. Treat it as an error. + return errors.New("could not find public key to match private key") + } + + switch keyType { + case pkcs11.CKK_DSA: + if pub, err = exportDSAPublicKey(session, *pubHandle); err != nil { + return err + } + k = &pkcs11PrivateKeyDSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: *privHandle, + context: c, + }, + pubKeyHandle: *pubHandle, + pubKey: pub, + }} + + case pkcs11.CKK_RSA: + if pub, err = exportRSAPublicKey(session, *pubHandle); err != nil { + return err + } + k = &pkcs11PrivateKeyRSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: *privHandle, + context: c, + }, + pubKeyHandle: *pubHandle, + pubKey: pub, + }} + + case pkcs11.CKK_ECDSA: + if pub, err = exportECDSAPublicKey(session, *pubHandle); err != nil { + return err + } + k = &pkcs11PrivateKeyECDSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: *privHandle, + context: c, + }, + pubKeyHandle: *pubHandle, + pubKey: pub, + }} + + default: + return errors.Errorf("unsupported key type: %X", keyType) + } + + return nil + }) + return k, err } // Public returns the public half of a private key. // // This partially implements the go.crypto.Signer and go.crypto.Decrypter interfaces for -// PKCS11PrivateKey. (The remains of the implementation is in the +// pkcs11PrivateKey. (The remains of the implementation is in the // key-specific types.) -func (signer PKCS11PrivateKey) Public() crypto.PublicKey { - return signer.PubKey +func (k pkcs11PrivateKey) Public() crypto.PublicKey { + return k.pubKey } -// FindKey retrieves a previously created symmetric key. +// FindKey retrieves a previously created symmetric key, or nil if it cannot be found. // // Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKey(id []byte, label []byte) (*PKCS11SecretKey, error) { - return FindKeyOnSlot(instance.slot, id, label) -} +func (c *Context) FindKey(id []byte, label []byte) (*SecretKey, error) { -// FindKeyOnSlot retrieves a previously created symmetric key, using a specified slot. -// -// Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKeyOnSlot(slot uint, id []byte, label []byte) (*PKCS11SecretKey, error) { - var err error - var k *PKCS11SecretKey - if err = ensureSessions(instance, slot); err != nil { - return nil, err + if c.closed.Get() { + return nil, errClosed } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = FindKeyOnSession(session, slot, id, label) - return err + + var k *SecretKey + + err := c.withSession(func(session *pkcs11Session) error { + privHandle, err := findKey(session, id, label, uintPtr(pkcs11.CKO_SECRET_KEY), nil) + if err != nil { + return err + } + if privHandle == nil { + // Key does not exist + return nil + } + + attributes := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, 0), + } + if attributes, err = session.ctx.GetAttributeValue(session.handle, *privHandle, attributes); err != nil { + return err + } + keyType := bytesToUlong(attributes[0].Value) + + if cipher, ok := Ciphers[int(keyType)]; ok { + k = &SecretKey{pkcs11Object{*privHandle, c}, cipher} + } else { + return errors.Errorf("unsupported key type: %X", keyType) + } + return nil }) + return k, err } -// FindKeyOnSession retrieves a previously created symmetric key, using a specified session. -// -// Either (but not both) of id and label may be nil, in which case they are ignored. -func FindKeyOnSession(session *PKCS11Session, slot uint, id []byte, label []byte) (key *PKCS11SecretKey, err error) { - var privHandle pkcs11.ObjectHandle - if privHandle, err = findKey(session, id, label, pkcs11.CKO_SECRET_KEY, ^uint(0)); err != nil { - return - } - attributes := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, 0), - } - if attributes, err = session.Ctx.GetAttributeValue(session.Handle, privHandle, attributes); err != nil { - return - } - if cipher, ok := Ciphers[int(bytesToUlong(attributes[0].Value))]; ok { - key = &PKCS11SecretKey{PKCS11Object{privHandle, slot}, cipher} - } else { - err = ErrUnsupportedKeyType - return - } - return -} +func uintPtr(i uint) *uint { return &i } diff --git a/keys_test.go b/keys_test.go new file mode 100644 index 0000000..152b33a --- /dev/null +++ b/keys_test.go @@ -0,0 +1,23 @@ +package crypto11 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFindKeysRequiresIdOrLabel(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + _, err = ctx.FindKey(nil, nil) + assert.Error(t, err) + + _, err = ctx.FindKeyPair(nil, nil) + assert.Error(t, err) +} diff --git a/pool_test.go b/pool_test.go deleted file mode 100644 index 5eb0d8e..0000000 --- a/pool_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 Thales e-Security, Inc -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -package crypto11 - -import ( - "crypto" - "crypto/elliptic" - "crypto/rand" - "fmt" - "github.com/miekg/pkcs11" - "github.com/stretchr/testify/require" - "testing" - "time" -) - -func TestPoolTimeout(t *testing.T) { - for _, d := range []time.Duration{0, time.Second} { - t.Run(fmt.Sprintf("first login, exp %v", d), func(t *testing.T) { - prevIdleTimeout := instance.cfg.IdleTimeout - defer func() { instance.cfg.IdleTimeout = prevIdleTimeout }() - instance.cfg.IdleTimeout = d - - _, err := configureWithPin(t) - require.NoError(t, err) - - defer func() { - require.NoError(t, Close()) - }() - - time.Sleep(instance.cfg.IdleTimeout + time.Second) - - _, err = GenerateECDSAKeyPair(elliptic.P256()) - if err != nil { - if perr, ok := err.(pkcs11.Error); ok && perr == pkcs11.CKR_USER_NOT_LOGGED_IN { - t.Fatal("pool handle session incorrectly, login required but missing:", err) - } else { - t.Fatal("failed to generate a key, unexpected error:", err) - } - } - }) - - t.Run(fmt.Sprintf("reuse expired handle, exp %v", d), func(t *testing.T) { - prevIdleTimeout := instance.cfg.IdleTimeout - defer func() { instance.cfg.IdleTimeout = prevIdleTimeout }() - instance.cfg.IdleTimeout = d - - _, err := configureWithPin(t) - require.NoError(t, err) - - defer func() { - require.NoError(t, Close()) - }() - - key, err := GenerateECDSAKeyPair(elliptic.P256()) - if err != nil { - t.Fatal("failed to generate a key:", err) - } - - time.Sleep(instance.cfg.IdleTimeout + time.Second) - - digest := crypto.SHA256.New() - digest.Write([]byte("sha256")) - _, err = key.Sign(rand.Reader, digest.Sum(nil), crypto.SHA256) - if err != nil { - if perr, ok := err.(pkcs11.Error); !ok || perr != pkcs11.CKR_OBJECT_HANDLE_INVALID { - t.Fatal("failed to reuse existing key handle, unexpected error:", err) - } - } - }) - } -} diff --git a/rand.go b/rand.go index 6f1ef24..c820e55 100644 --- a/rand.go +++ b/rand.go @@ -21,20 +21,30 @@ package crypto11 -// PKCS11RandReader is a random number reader that uses PKCS#11. -type PKCS11RandReader struct { +import ( + "io" +) + +// NewRandomReader returns a reader for the random number generator on the token. +func (c *Context) NewRandomReader() (io.Reader, error) { + if c.closed.Get() { + return nil, errClosed + } + + return pkcs11RandReader{c}, nil } -// Read fills data with random bytes generated via PKCS#11 using the default slot. -// -// This implements the Reader interface for PKCS11RandReader. -func (reader PKCS11RandReader) Read(data []byte) (n int, err error) { +// pkcs11RandReader is a random number reader that uses PKCS#11. +type pkcs11RandReader struct { + context *Context +} + +// This implements the Reader interface for pkcs11RandReader. +func (r pkcs11RandReader) Read(data []byte) (n int, err error) { var result []byte - if instance.ctx == nil { - return 0, ErrNotConfigured - } - if err = withSession(instance.slot, func(session *PKCS11Session) error { - result, err = instance.ctx.GenerateRandom(session.Handle, len(data)) + + if err = r.context.withSession(func(session *pkcs11Session) error { + result, err = r.context.ctx.GenerateRandom(session.handle, len(data)) return err }); err != nil { return 0, err diff --git a/rand_test.go b/rand_test.go index 52ea4fe..f0fea63 100644 --- a/rand_test.go +++ b/rand_test.go @@ -22,26 +22,27 @@ package crypto11 import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestRandomReader(t *testing.T) { - var a [32768]byte - var r PKCS11RandReader - var n int - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") require.NoError(t, err) + defer func() { + err = ctx.Close() + require.NoError(t, err) + }() + + reader, err := ctx.NewRandomReader() + require.NoError(t, err) + + var a [32768]byte for _, size := range []int{1, 16, 32, 256, 347, 4096, 32768} { - if n, err = r.Read(a[:size]); err != nil { - t.Errorf("crypto11.PKCS11RandRead.Read: %v", err) - return - } - if n < size { - t.Errorf("crypto11.PKCS11RandRead.Read: only got %d bytes expected %d", n, size) - return - } + n, err := reader.Read(a[:size]) + require.NoError(t, err) + require.Equal(t, size, n) } - require.NoError(t, Close()) } diff --git a/rsa.go b/rsa.go index df1eb79..9b1554d 100644 --- a/rsa.go +++ b/rsa.go @@ -29,38 +29,34 @@ import ( "math/big" "unsafe" - pkcs11 "github.com/miekg/pkcs11" + "github.com/miekg/pkcs11" ) -// ErrMalformedRSAKey is returned when an RSA key is not in a suitable form. +// errMalformedRSAPublicKey is returned when an RSA public key is not in a suitable form. // // Currently this means that the public exponent is either bigger than // 32 bits, or less than 2. -var ErrMalformedRSAKey = errors.New("crypto11/rsa: malformed RSA key") +var errMalformedRSAPublicKey = errors.New("malformed RSA public key") -// ErrUnrecognizedRSAOptions is returned when unrecognized options -// structures are pased to Sign or Decrypt. -var ErrUnrecognizedRSAOptions = errors.New("crypto11/rsa: unrecognized RSA options type") - -// ErrUnsupportedRSAOptions is returned when an unsupported RSA option is requested. +// errUnsupportedRSAOptions is returned when an unsupported RSA option is requested. // // Currently this means a nontrivial SessionKeyLen when decrypting; or // an unsupported hash function; or crypto.rsa.PSSSaltLengthAuto was // requested. -var ErrUnsupportedRSAOptions = errors.New("crypto11/rsa: unsupported RSA option value") +var errUnsupportedRSAOptions = errors.New("unsupported RSA option value") -// PKCS11PrivateKeyRSA contains a reference to a loaded PKCS#11 RSA private key object. -type PKCS11PrivateKeyRSA struct { - PKCS11PrivateKey +// pkcs11PrivateKeyRSA contains a reference to a loaded PKCS#11 RSA private key object. +type pkcs11PrivateKeyRSA struct { + pkcs11PrivateKey } // Export the public key corresponding to a private RSA key. -func exportRSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { +func exportRSAPublicKey(session *pkcs11Session, pubHandle pkcs11.ObjectHandle) (crypto.PublicKey, error) { template := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil), pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil), } - exported, err := session.Ctx.GetAttributeValue(session.Handle, pubHandle, template) + exported, err := session.ctx.GetAttributeValue(session.handle, pubHandle, template) if err != nil { return nil, err } @@ -69,10 +65,10 @@ func exportRSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) ( var bigExponent = new(big.Int) bigExponent.SetBytes(exported[1].Value) if bigExponent.BitLen() > 32 { - return nil, ErrMalformedRSAKey + return nil, errMalformedRSAPublicKey } if bigExponent.Sign() < 1 { - return nil, ErrMalformedRSAKey + return nil, errMalformedRSAPublicKey } exponent := int(bigExponent.Uint64()) result := rsa.PublicKey{ @@ -80,101 +76,112 @@ func exportRSAPublicKey(session *PKCS11Session, pubHandle pkcs11.ObjectHandle) ( E: exponent, } if result.E < 2 { - return nil, ErrMalformedRSAKey + return nil, errMalformedRSAPublicKey } return &result, nil } -// GenerateRSAKeyPair creates an RSA private key of given length. -// -// The key will have a random label and ID. -// -// RSA private keys are generated with both sign and decrypt -// permissions, and a public exponent of 65537. -func GenerateRSAKeyPair(bits int) (*PKCS11PrivateKeyRSA, error) { - return GenerateRSAKeyPairOnSlot(instance.slot, nil, nil, bits) -} +// GenerateRSAKeyPair creates an RSA key pair on the token. The id parameter is used to +// set CKA_ID and must be non-nil. +func (c *Context) GenerateRSAKeyPair(id []byte, bits int) (SignerDecrypter, error) { + if c.closed.Get() { + return nil, errClosed + } -// GenerateRSAKeyPairOnSlot creates a RSA private key on a specified slot -// -// Either or both label and/or id can be nil, in which case random values will be generated. -func GenerateRSAKeyPairOnSlot(slot uint, id []byte, label []byte, bits int) (*PKCS11PrivateKeyRSA, error) { - var k *PKCS11PrivateKeyRSA - var err error - if err = ensureSessions(instance, slot); err != nil { + if err := notNilBytes(id, "id"); err != nil { return nil, err } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = GenerateRSAKeyPairOnSession(session, slot, id, label, bits) - return err - }) - return k, err -} -// GenerateRSAKeyPairOnSession creates an RSA private key of given length, on a specified session. -// -// Either or both label and/or id can be nil, in which case random values will be generated. -// -// RSA private keys are generated with both sign and decrypt -// permissions, and a public exponent of 65537. -func GenerateRSAKeyPairOnSession(session *PKCS11Session, slot uint, id []byte, label []byte, bits int) (*PKCS11PrivateKeyRSA, error) { - var err error - var pub crypto.PublicKey + return c.generateRSAKeyPair(id, nil, bits) +} - if label == nil { - if label, err = generateKeyLabel(); err != nil { - return nil, err - } - } - if id == nil { - if id, err = generateKeyLabel(); err != nil { - return nil, err - } - } - publicKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_RSA), - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), - pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), - pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, []byte{1, 0, 1}), - pkcs11.NewAttribute(pkcs11.CKA_MODULUS_BITS, bits), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), +// GenerateRSAKeyPairWithLabel creates an RSA key pair on the token. The id and label parameters are used to +// set CKA_ID and CKA_LABEL respectively and must be non-nil. +func (c *Context) GenerateRSAKeyPairWithLabel(id, label []byte, bits int) (SignerDecrypter, error) { + if c.closed.Get() { + return nil, errClosed } - privateKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), - pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), - pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - } - mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_KEY_PAIR_GEN, nil)} - pubHandle, privHandle, err := session.Ctx.GenerateKeyPair(session.Handle, - mech, - publicKeyTemplate, - privateKeyTemplate) - if err != nil { + + if err := notNilBytes(id, "id"); err != nil { return nil, err } - if pub, err = exportRSAPublicKey(session, pubHandle); err != nil { + if err := notNilBytes(label, "label"); err != nil { return nil, err } - priv := PKCS11PrivateKeyRSA{PKCS11PrivateKey{PKCS11Object{privHandle, slot}, pub}} - return &priv, nil + + return c.generateRSAKeyPair(id, label, bits) +} + +// GenerateRSAKeyPair creates an RSA private key of given length. The CKA_ID and CKA_LABEL attributes can be set by passing +// non-nil values for id and label. +// +// RSA private keys are generated with both sign and decrypt permissions, and a public exponent of 65537. +func (c *Context) generateRSAKeyPair(id, label []byte, bits int) (k SignerDecrypter, err error) { + err = c.withSession(func(session *pkcs11Session) error { + + publicKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_RSA), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, []byte{1, 0, 1}), + pkcs11.NewAttribute(pkcs11.CKA_MODULUS_BITS, bits), + } + privateKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + } + + if id != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + + if label != nil { + publicKeyTemplate = append(publicKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + privateKeyTemplate = append(privateKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_KEY_PAIR_GEN, nil)} + pubHandle, privHandle, err := session.ctx.GenerateKeyPair(session.handle, + mech, + publicKeyTemplate, + privateKeyTemplate) + if err != nil { + return err + } + + pub, err := exportRSAPublicKey(session, pubHandle) + if err != nil { + return err + } + k = &pkcs11PrivateKeyRSA{ + pkcs11PrivateKey: pkcs11PrivateKey{ + pkcs11Object: pkcs11Object{ + handle: privHandle, + context: c, + }, + pubKeyHandle: pubHandle, + pubKey: pub, + }} + return nil + }) + return } // Decrypt decrypts a message using a RSA key. // -// This completes the implemention of crypto.Decrypter for PKCS11PrivateKeyRSA. +// This completes the implemention of crypto.Decrypter for pkcs11PrivateKeyRSA. // // Note that the SessionKeyLen option (for PKCS#1v1.5 decryption) is not supported. // // The underlying PKCS#11 implementation may impose further restrictions. -func (priv *PKCS11PrivateKeyRSA) Decrypt(rand io.Reader, ciphertext []byte, options crypto.DecrypterOpts) (plaintext []byte, err error) { - err = withSession(priv.Slot, func(session *PKCS11Session) error { +func (priv *pkcs11PrivateKeyRSA) Decrypt(rand io.Reader, ciphertext []byte, options crypto.DecrypterOpts) (plaintext []byte, err error) { + err = priv.context.withSession(func(session *pkcs11Session) error { if options == nil { plaintext, err = decryptPKCS1v15(session, priv, ciphertext, 0) } else { @@ -184,7 +191,7 @@ func (priv *PKCS11PrivateKeyRSA) Decrypt(rand io.Reader, ciphertext []byte, opti case *rsa.OAEPOptions: plaintext, err = decryptOAEP(session, priv, ciphertext, o.Hash, o.Label) default: - err = ErrUnsupportedRSAOptions + err = errUnsupportedRSAOptions } } return err @@ -192,24 +199,24 @@ func (priv *PKCS11PrivateKeyRSA) Decrypt(rand io.Reader, ciphertext []byte, opti return plaintext, err } -func decryptPKCS1v15(session *PKCS11Session, key *PKCS11PrivateKeyRSA, ciphertext []byte, sessionKeyLen int) ([]byte, error) { +func decryptPKCS1v15(session *pkcs11Session, key *pkcs11PrivateKeyRSA, ciphertext []byte, sessionKeyLen int) ([]byte, error) { if sessionKeyLen != 0 { - return nil, ErrUnsupportedRSAOptions + return nil, errUnsupportedRSAOptions } mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS, nil)} - if err := session.Ctx.DecryptInit(session.Handle, mech, key.Handle); err != nil { + if err := session.ctx.DecryptInit(session.handle, mech, key.handle); err != nil { return nil, err } - return session.Ctx.Decrypt(session.Handle, ciphertext) + return session.ctx.Decrypt(session.handle, ciphertext) } -func decryptOAEP(session *PKCS11Session, key *PKCS11PrivateKeyRSA, ciphertext []byte, hashFunction crypto.Hash, label []byte) ([]byte, error) { +func decryptOAEP(session *pkcs11Session, key *pkcs11PrivateKeyRSA, ciphertext []byte, hashFunction crypto.Hash, label []byte) ([]byte, error) { var err error var hMech, mgf, sourceData, sourceDataLen uint if hMech, mgf, _, err = hashToPKCS11(hashFunction); err != nil { return nil, err } - if label != nil && len(label) > 0 { + if len(label) > 0 { sourceData = uint(uintptr(unsafe.Pointer(&label[0]))) sourceDataLen = uint(len(label)) } @@ -219,10 +226,10 @@ func decryptOAEP(session *PKCS11Session, key *PKCS11PrivateKeyRSA, ciphertext [] ulongToBytes(sourceData), ulongToBytes(sourceDataLen)) mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, parameters)} - if err = session.Ctx.DecryptInit(session.Handle, mech, key.Handle); err != nil { + if err = session.ctx.DecryptInit(session.handle, mech, key.handle); err != nil { return nil, err } - return session.Ctx.Decrypt(session.Handle, ciphertext) + return session.ctx.Decrypt(session.handle, ciphertext) } func hashToPKCS11(hashFunction crypto.Hash) (uint, uint, uint, error) { @@ -238,11 +245,11 @@ func hashToPKCS11(hashFunction crypto.Hash) (uint, uint, uint, error) { case crypto.SHA512: return pkcs11.CKM_SHA512, pkcs11.CKG_MGF1_SHA512, 64, nil default: - return 0, 0, 0, ErrUnsupportedRSAOptions + return 0, 0, 0, errUnsupportedRSAOptions } } -func signPSS(session *PKCS11Session, key *PKCS11PrivateKeyRSA, digest []byte, opts *rsa.PSSOptions) ([]byte, error) { +func signPSS(session *pkcs11Session, key *pkcs11PrivateKeyRSA, digest []byte, opts *rsa.PSSOptions) ([]byte, error) { var hMech, mgf, hLen, sLen uint var err error if hMech, mgf, hLen, err = hashToPKCS11(opts.Hash); err != nil { @@ -253,7 +260,7 @@ func signPSS(session *PKCS11Session, key *PKCS11PrivateKeyRSA, digest []byte, op // TODO we could (in principle) work out the biggest // possible size from the key, but until someone has // the effort to do that... - return nil, ErrUnsupportedRSAOptions + return nil, errUnsupportedRSAOptions case rsa.PSSSaltLengthEqualsHash: sLen = hLen default: @@ -265,37 +272,37 @@ func signPSS(session *PKCS11Session, key *PKCS11PrivateKeyRSA, digest []byte, op ulongToBytes(mgf), ulongToBytes(sLen)) mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_PSS, parameters)} - if err = session.Ctx.SignInit(session.Handle, mech, key.Handle); err != nil { + if err = session.ctx.SignInit(session.handle, mech, key.handle); err != nil { return nil, err } - return session.Ctx.Sign(session.Handle, digest) + return session.ctx.Sign(session.handle, digest) } var pkcs1Prefix = map[crypto.Hash][]byte{ - crypto.SHA1: []byte{0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, - crypto.SHA224: []byte{0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c}, - crypto.SHA256: []byte{0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, - crypto.SHA384: []byte{0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, - crypto.SHA512: []byte{0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, + crypto.SHA1: {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, + crypto.SHA224: {0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c}, + crypto.SHA256: {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, + crypto.SHA384: {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, + crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, } -func signPKCS1v15(session *PKCS11Session, key *PKCS11PrivateKeyRSA, digest []byte, hash crypto.Hash) (signature []byte, err error) { +func signPKCS1v15(session *pkcs11Session, key *pkcs11PrivateKeyRSA, digest []byte, hash crypto.Hash) (signature []byte, err error) { /* Calculate T for EMSA-PKCS1-v1_5. */ oid := pkcs1Prefix[hash] T := make([]byte, len(oid)+len(digest)) copy(T[0:len(oid)], oid) copy(T[len(oid):], digest) mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS, nil)} - err = session.Ctx.SignInit(session.Handle, mech, key.Handle) + err = session.ctx.SignInit(session.handle, mech, key.handle) if err == nil { - signature, err = session.Ctx.Sign(session.Handle, T) + signature, err = session.ctx.Sign(session.handle, T) } return } // Sign signs a message using a RSA key. // -// This completes the implemention of crypto.Signer for PKCS11PrivateKeyRSA. +// This completes the implemention of crypto.Signer for pkcs11PrivateKeyRSA. // // PKCS#11 expects to pick its own random data where necessary for signatures, so the rand argument is ignored. // @@ -304,11 +311,8 @@ func signPKCS1v15(session *PKCS11Session, key *PKCS11PrivateKeyRSA, digest []byt // crypto.rsa.PSSSaltLengthEqualsHash (recommended) or pass an // explicit salt length. Moreover the underlying PKCS#11 // implementation may impose further restrictions. -func (priv *PKCS11PrivateKeyRSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { - if err != nil { - return nil, err - } - err = withSession(priv.Slot, func(session *PKCS11Session) error { +func (priv *pkcs11PrivateKeyRSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + err = priv.context.withSession(func(session *pkcs11Session) error { switch opts.(type) { case *rsa.PSSOptions: signature, err = signPSS(session, priv, digest, opts.(*rsa.PSSOptions)) @@ -317,22 +321,10 @@ func (priv *PKCS11PrivateKeyRSA) Sign(rand io.Reader, digest []byte, opts crypto } return err }) - return signature, err -} -// Validate checks an RSA key. -// -// Since the private key material is not normally available only very -// limited validation is possible. (The underlying PKCS#11 -// implementation may perform stricter checking.) -func (priv *PKCS11PrivateKeyRSA) Validate() error { - pub := priv.PubKey.(*rsa.PublicKey) - if pub.E < 2 { - return ErrMalformedRSAKey + if err != nil { + return nil, err } - // The software implementation actively rejects 'large' public - // exponents, in order to simplify its own implementation. - // Here, instead, we expect the PKCS#11 library to enforce its - // own preferred constraints, whatever they might be. - return nil + + return signature, err } diff --git a/rsa_test.go b/rsa_test.go index e6a3f27..ba90a35 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -30,111 +30,96 @@ import ( _ "crypto/sha256" _ "crypto/sha512" "fmt" + "testing" + "github.com/miekg/pkcs11" "github.com/stretchr/testify/require" - "testing" ) var rsaSizes = []int{1024, 2048} func TestNativeRSA(t *testing.T) { - var key *rsa.PrivateKey - _, err := ConfigureFromFile("config") + + ctx, err := ConfigureFromFile("config") require.NoError(t, err) + defer func() { + require.NoError(t, ctx.Close()) + }() + for _, nbits := range rsaSizes { t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) { - t.Run("Generate", func(t *testing.T) { - if key, err = rsa.GenerateKey(rand.Reader, nbits); err != nil { - t.Errorf("crypto.rsa.GenerateKey: %v", err) - return - } - if err = key.Validate(); err != nil { - t.Errorf("crypto.rsa.PrivateKey.Validate: %v", err) - return - } - }) - t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, ^uint(0)) }) - t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, ^uint(0)) }) + key, err := rsa.GenerateKey(rand.Reader, nbits) + require.NoError(t, err) + + err = key.Validate() + require.NoError(t, err) + + t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, true) }) + t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, true) }) }) } - - require.NoError(t, Close()) } func TestHardRSA(t *testing.T) { - var key *PKCS11PrivateKeyRSA - var key2, key3 crypto.PrivateKey - var id, label []byte - - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") require.NoError(t, err) + defer func() { + require.NoError(t, ctx.Close()) + }() + for _, nbits := range rsaSizes { + id := randomBytes() + label := randomBytes() + t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) { - t.Run("Generate", func(t *testing.T) { - if key, err = GenerateRSAKeyPair(nbits); err != nil { - t.Errorf("crypto11.GenerateRSAKeyPair: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.dsa.GenerateRSAKeyPair: returned nil but no error") - return - } - if err = key.Validate(); err != nil { - t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Validate: %v", err) - return - } - }) - t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, key.Slot) }) - t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, key.Slot) }) + + key, err := ctx.GenerateRSAKeyPairWithLabel(id, label, nbits) + require.NoError(t, err) + require.NotNil(t, key) + + var key2, key3 crypto.PrivateKey + + t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, false) }) + t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, false) }) t.Run("FindId", func(t *testing.T) { - // Get a fresh handle to the key - if id, label, err = key.Identify(); err != nil { - t.Errorf("crypto11.rsa.PKCS11PrivateKeyRSA.Identify: %v", err) - return - } - if key2, err = FindKeyPair(id, nil); err != nil { - t.Errorf("crypto11.rsa.FindRSAKeyPair by id: %v", err) - return - } + key2, err = ctx.FindKeyPair(id, nil) + require.NoError(t, err) }) t.Run("SignId", func(t *testing.T) { if key2 == nil { t.SkipNow() } - testRsaSigning(t, key2.(*PKCS11PrivateKeyRSA), nbits, key2.(*PKCS11PrivateKeyRSA).Slot) + testRsaSigning(t, key2.(*pkcs11PrivateKeyRSA), nbits, false) }) t.Run("FindLabel", func(t *testing.T) { - if key3, err = FindKeyPair(nil, label); err != nil { - t.Errorf("crypto11.rsa.FindKeyPair by label: %v", err) - return - } + key3, err = ctx.FindKeyPair(nil, label) + require.NoError(t, err) }) t.Run("SignLabel", func(t *testing.T) { if key3 == nil { t.SkipNow() } - testRsaSigning(t, key3.(crypto.Signer), nbits, key3.(*PKCS11PrivateKeyRSA).Slot) + testRsaSigning(t, key3.(crypto.Signer), nbits, false) }) }) } - require.NoError(t, Close()) } -func testRsaSigning(t *testing.T, key crypto.Signer, nbits int, slot uint) { +func testRsaSigning(t *testing.T, key crypto.Signer, nbits int, native bool) { t.Run("SHA1", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA1) }) t.Run("SHA224", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA224) }) t.Run("SHA256", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA256) }) t.Run("SHA384", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA384) }) t.Run("SHA512", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA512) }) - t.Run("PSSSHA1", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA1, slot) }) - t.Run("PSSSHA224", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA224, slot) }) - t.Run("PSSSHA256", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA256, slot) }) - t.Run("PSSSHA384", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA384, slot) }) + t.Run("PSSSHA1", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA1, native) }) + t.Run("PSSSHA224", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA224, native) }) + t.Run("PSSSHA256", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA256, native) }) + t.Run("PSSSHA384", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA384, native) }) t.Run("PSSSHA512", func(t *testing.T) { if nbits > 1024 { - testRsaSigningPSS(t, key, crypto.SHA512, slot) + testRsaSigningPSS(t, key, crypto.SHA512, native) } else { t.Skipf("key too smol for SHA512 with sLen=hLen") } @@ -142,68 +127,69 @@ func testRsaSigning(t *testing.T, key crypto.Signer, nbits int, slot uint) { } func testRsaSigningPKCS1v15(t *testing.T, key crypto.Signer, hashFunction crypto.Hash) { - var err error - var sig []byte - plaintext := []byte("sign me with PKCS#1 v1.5") h := hashFunction.New() - h.Write(plaintext) + _, err := h.Write(plaintext) + require.NoError(t, err) plaintextHash := h.Sum([]byte{}) // weird API - if sig, err = key.Sign(rand.Reader, plaintextHash, hashFunction); err != nil { - t.Errorf("PKCS#1 v1.5 Sign (hash %v): %v", hashFunction, err) - return - } + + sig, err := key.Sign(rand.Reader, plaintextHash, hashFunction) + require.NoError(t, err) + rsaPubkey := key.Public().(crypto.PublicKey).(*rsa.PublicKey) - if err = rsa.VerifyPKCS1v15(rsaPubkey, hashFunction, plaintextHash, sig); err != nil { - t.Errorf("PKCS#1 v1.5 Verify (hash %v): %v", hashFunction, err) - } + err = rsa.VerifyPKCS1v15(rsaPubkey, hashFunction, plaintextHash, sig) + require.NoError(t, err) } -func testRsaSigningPSS(t *testing.T, key crypto.Signer, hashFunction crypto.Hash, slot uint) { - var err error - var sig []byte +func testRsaSigningPSS(t *testing.T, key crypto.Signer, hashFunction crypto.Hash, native bool) { + + if !native { + skipIfMechUnsupported(t, key.(*pkcs11PrivateKeyRSA).context, pkcs11.CKM_RSA_PKCS_PSS) + } - needMechanism(t, slot, pkcs11.CKM_RSA_PKCS_PSS) plaintext := []byte("sign me with PSS") h := hashFunction.New() - h.Write(plaintext) + _, err := h.Write(plaintext) + require.NoError(t, err) + plaintextHash := h.Sum([]byte{}) // weird API pssOptions := &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunction, } - if sig, err = key.Sign(rand.Reader, plaintextHash, pssOptions); err != nil { - t.Errorf("PSS Sign (hash %v): %v", hashFunction, err) - return - } + sig, err := key.Sign(rand.Reader, plaintextHash, pssOptions) + require.NoError(t, err) + rsaPubkey := key.Public().(crypto.PublicKey).(*rsa.PublicKey) - if err = rsa.VerifyPSS(rsaPubkey, hashFunction, plaintextHash, sig, pssOptions); err != nil { - t.Errorf("PSS Verify (hash %v): %v", hashFunction, err) - } + + err = rsa.VerifyPSS(rsaPubkey, hashFunction, plaintextHash, sig, pssOptions) + require.NoError(t, err) } -func testRsaEncryption(t *testing.T, key crypto.Decrypter, nbits int, slot uint) { +func testRsaEncryption(t *testing.T, key crypto.Decrypter, nbits int, native bool) { t.Run("PKCS1v15", func(t *testing.T) { testRsaEncryptionPKCS1v15(t, key) }) - t.Run("OAEPSHA1", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{}, slot) }) - t.Run("OAEPSHA224", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{}, slot) }) - t.Run("OAEPSHA256", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{}, slot) }) - t.Run("OAEPSHA384", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{}, slot) }) + t.Run("OAEPSHA1", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{}, native) }) + t.Run("OAEPSHA224", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{}, native) }) + t.Run("OAEPSHA256", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{}, native) }) + t.Run("OAEPSHA384", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{}, native) }) t.Run("OAEPSHA512", func(t *testing.T) { if nbits > 1024 { - testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{}, slot) + testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{}, native) } else { - t.Skipf("key too smol for SHA512") + t.Skipf("key too small for SHA512") } }) - t.Run("OAEPSHA1Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4}, slot) }) - t.Run("OAEPSHA224Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8}, slot) }) - t.Run("OAEPSHA256Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9}, slot) }) - t.Run("OAEPSHA384Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15}, slot) }) + t.Run("OAEPSHA1Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4}, native) }) + t.Run("OAEPSHA224Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8}, native) }) + t.Run("OAEPSHA256Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9}, native) }) + t.Run("OAEPSHA384Label", func(t *testing.T) { + testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15}, native) + }) t.Run("OAEPSHA512Label", func(t *testing.T) { if nbits > 1024 { - testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18}, slot) + testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18}, native) } else { - t.Skipf("key too smol for SHA512") + t.Skipf("key too small for SHA512") } }) } @@ -222,7 +208,7 @@ func testRsaEncryptionPKCS1v15(t *testing.T, key crypto.Decrypter) { t.Errorf("PKCS#1v1.5 Decrypt (nil options): %v", err) return } - if bytes.Compare(plaintext, decrypted) != 0 { + if !bytes.Equal(plaintext, decrypted) { t.Errorf("PKCS#1v1.5 Decrypt (nil options): wrong answer") return } @@ -233,57 +219,46 @@ func testRsaEncryptionPKCS1v15(t *testing.T, key crypto.Decrypter) { t.Errorf("PKCS#1v1.5 Decrypt %v", err) return } - if bytes.Compare(plaintext, decrypted) != 0 { + if !bytes.Equal(plaintext, decrypted) { t.Errorf("PKCS#1v1.5 Decrypt: wrong answer") return } } -func testRsaEncryptionOAEP(t *testing.T, key crypto.Decrypter, hashFunction crypto.Hash, label []byte, slot uint) { - var err error - var ciphertext, decrypted []byte - needMechanism(t, slot, pkcs11.CKM_RSA_PKCS_OAEP) - // Doesn't seem to be a way to query supported MGFs so we do that the hard way. - var info pkcs11.Info - if info, err = instance.ctx.GetInfo(); err != nil { - t.Errorf("GetInfo: %v", err) - return - } - if info.ManufacturerID == "SoftHSM" && (hashFunction != crypto.SHA1 || len(label) > 0) { - t.Skipf("SoftHSM OAEP only supports SHA-1 with no label") +func testRsaEncryptionOAEP(t *testing.T, key crypto.Decrypter, hashFunction crypto.Hash, label []byte, native bool) { + if !native { + skipIfMechUnsupported(t, key.(*pkcs11PrivateKeyRSA).context, pkcs11.CKM_RSA_PKCS_OAEP) + + // Doesn't seem to be a way to query supported MGFs so we do that the hard way. + info, err := key.(*pkcs11PrivateKeyRSA).context.ctx.GetInfo() + require.NoError(t, err) + + if info.ManufacturerID == "SoftHSM" && (hashFunction != crypto.SHA1 || len(label) > 0) { + t.Skipf("SoftHSM OAEP only supports SHA-1 with no label") + } } + plaintext := []byte("encrypt me with new hotness") h := hashFunction.New() rsaPubkey := key.Public().(crypto.PublicKey).(*rsa.PublicKey) - if ciphertext, err = rsa.EncryptOAEP(h, rand.Reader, rsaPubkey, plaintext, label); err != nil { - t.Errorf("OAEP Encrypt: %v", err) - return - } + + ciphertext, err := rsa.EncryptOAEP(h, rand.Reader, rsaPubkey, plaintext, label) + require.NoError(t, err) + options := &rsa.OAEPOptions{ Hash: hashFunction, Label: label, } - if decrypted, err = key.Decrypt(rand.Reader, ciphertext, options); err != nil { - t.Errorf("OAEP Decrypt %v", err) - return - } - if bytes.Compare(plaintext, decrypted) != 0 { - t.Errorf("OAEP Decrypt: wrong answer") - return - } + decrypted, err := key.Decrypt(rand.Reader, ciphertext, options) + require.NoError(t, err) + + require.Equal(t, plaintext, decrypted) } -func needMechanism(t *testing.T, slot uint, wantMech uint) { - var err error - var mechs []*pkcs11.Mechanism +func skipIfMechUnsupported(t *testing.T, ctx *Context, wantMech uint) { + mechs, err := ctx.ctx.GetMechanismList(ctx.slot) + require.NoError(t, err) - if slot == ^uint(0) { // not using PKCS#11 - return - } - if mechs, err = instance.ctx.GetMechanismList(slot); err != nil { - t.Errorf("GetMechanismList: %v", err) - return - } for _, mech := range mechs { if mech.Mechanism == wantMech { return @@ -291,3 +266,23 @@ func needMechanism(t *testing.T, slot uint, wantMech uint) { } t.Skipf("mechanism %v not supported", wantMech) } + +func TestRsaRequiredArgs(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + _, err = ctx.GenerateRSAKeyPair(nil, 2048) + require.Error(t, err) + + val := randomBytes() + + _, err = ctx.GenerateRSAKeyPairWithLabel(nil, val, 2048) + require.Error(t, err) + + _, err = ctx.GenerateRSAKeyPairWithLabel(val, nil, 2048) + require.Error(t, err) +} diff --git a/sessions.go b/sessions.go index 4d235bc..26e5a53 100644 --- a/sessions.go +++ b/sessions.go @@ -23,193 +23,63 @@ package crypto11 import ( "context" - "errors" - "fmt" + "github.com/miekg/pkcs11" - "github.com/youtube/vitess/go/pools" - "log" - "sync" + "github.com/vitessio/vitess/go/pools" ) -// PKCS11Session is a pair of PKCS#11 context and a reference to a loaded session handle. -type PKCS11Session struct { - Ctx *pkcs11.Ctx - Handle pkcs11.SessionHandle +// pkcs11Session wraps a PKCS#11 session handle so we can use it in a resource pool. +type pkcs11Session struct { + ctx *pkcs11.Ctx + handle pkcs11.SessionHandle } -// sessionPool is a thread safe pool of PKCS#11 sessions -type sessionPool struct { - m sync.RWMutex - pool map[uint]*pools.ResourcePool +// Close is required to satisfy the pools.Resource interface. It closes the session, but swallows any +// errors that occur. +func (s pkcs11Session) Close() { + // We cannot return an error, so we swallow it + _ = s.ctx.CloseSession(s.handle) } -// Map of slot IDs to session pools -var pool = newSessionPool() - -// Error specifies an event when the requested slot is already set in the sessions pool -var errSlotBusy = errors.New("pool slot busy") - -// Error when there is no pool at specific slot in the sessions pool -var errPoolNotFound = errors.New("pool not found") - -// Create a new session for a given slot -func newSession(ctx *pkcs11.Ctx, slot uint) (*PKCS11Session, error) { - session, err := ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) +// withSession executes a function with a session. +func (c *Context) withSession(f func(session *pkcs11Session) error) error { + session, err := c.getSession() if err != nil { - return nil, err - } - return &PKCS11Session{ctx, session}, nil -} - -// Create a new session pool with default configuration -func newSessionPool() *sessionPool { - return &sessionPool{ - pool: map[uint]*pools.ResourcePool{}, + return err } -} - -// Close closes the session. -// -// Deprecated: Use CloseSession, which returns any underlying errors. -func (session *PKCS11Session) Close() { - // TODO - when next making breaking changes, kill this method (or fix it) + defer c.pool.Put(session) - // Assign error to "_", to indicate we are knowingly ignoring it - _ = session.Ctx.CloseSession(session.Handle) + return f(session) } -// CloseSession closes the session. -func (session *PKCS11Session) CloseSession() error { - return session.Ctx.CloseSession(session.Handle) -} - -// Get returns requested resource pool by slot id -func (p *sessionPool) Get(slot uint) *pools.ResourcePool { - p.m.RLock() - defer p.m.RUnlock() - return p.pool[slot] -} - -// Put stores new resource pool into the pool if the requested slot is free -func (p *sessionPool) PutIfAbsent(slot uint, pool *pools.ResourcePool) error { - p.m.Lock() - defer p.m.Unlock() - if _, ok := p.pool[slot]; ok { - return errSlotBusy - } - p.pool[slot] = pool - return nil -} - -// Run a function with a session -// -// setupSessions must have been called for the slot already, otherwise -// an error will be returned. -func withSession(slot uint, f func(session *PKCS11Session) error) error { - sessionPool := pool.Get(slot) - if sessionPool == nil { - return fmt.Errorf("crypto11: no session for slot %d", slot) - } - +// getSession retrieves a session from the pool, respecting the timeout defined in the Context config. +// Callers are responsible for putting this session back in the pool. +func (c *Context) getSession() (*pkcs11Session, error) { ctx := context.Background() - if instance.cfg.PoolWaitTimeout > 0 { + + if c.cfg.PoolWaitTimeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(context.Background(), instance.cfg.PoolWaitTimeout) + ctx, cancel = context.WithTimeout(context.Background(), c.cfg.PoolWaitTimeout) defer cancel() } - session, err := sessionPool.Get(ctx) - if err != nil { - return err + resource, err := c.pool.Get(ctx) + if err == pools.ErrClosed { + // Our Context must have been closed, return a nicer error + return nil, errClosed } - defer sessionPool.Put(session) - - s := session.(*PKCS11Session) - err = f(s) if err != nil { - // if a request required login, then try to login - if perr, ok := err.(pkcs11.Error); ok && perr == pkcs11.CKR_USER_NOT_LOGGED_IN && instance.cfg.Pin != "" { - if err = s.Ctx.Login(s.Handle, pkcs11.CKU_USER, instance.cfg.Pin); err != nil { - return err - } - // retry after login - return f(s) - } - - return err + return nil, err } - return nil -} - -// Ensures that sessions are setup. -func ensureSessions(ctx *libCtx, slot uint) error { - if err := setupSessions(ctx, slot); err != nil && err != errSlotBusy { - return err - } - return nil + return resource.(*pkcs11Session), nil } -// Create the session pool for a given slot if it does not exist -// already. -func setupSessions(c *libCtx, slot uint) error { - return pool.PutIfAbsent(slot, pools.NewResourcePool( - func() (pools.Resource, error) { - s, err := newSession(c.ctx, slot) - if err != nil { - return nil, err - } - - if instance.token.Flags&pkcs11.CKF_LOGIN_REQUIRED != 0 && instance.cfg.Pin != "" { - // login required if a pool evict idle sessions or - // for the first connection in the pool (handled in lib conf) - if instance.cfg.IdleTimeout > 0 { - if err = loginToken(s); err != nil { - return nil, err - } - } - } - - return s, nil - }, - c.cfg.MaxSessions, - c.cfg.MaxSessions, - c.cfg.IdleTimeout, - )) -} - -func loginToken(s *PKCS11Session) error { - // login is pkcs11 context wide, not just handle/session scoped - err := s.Ctx.Login(s.Handle, pkcs11.CKU_USER, instance.cfg.Pin) +// resourcePoolFactoryFunc is called by the resource pool when a new session is needed. +func (c *Context) resourcePoolFactoryFunc() (pools.Resource, error) { + session, err := c.ctx.OpenSession(c.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) if err != nil { - if code, ok := err.(pkcs11.Error); ok && code == pkcs11.CKR_USER_ALREADY_LOGGED_IN { - return nil - } - log.Printf("Failed to open PKCS#11 Session: %s", err.Error()) - - closeErr := s.CloseSession() - if closeErr != nil { - log.Printf("Failed to close session: %s", closeErr.Error()) - } - - // Return the first error we encountered - return err - } - return nil -} - -// Releases a sessions specific to the requested slot if present. -func (p *sessionPool) closeSessions(slot uint) error { - p.m.Lock() - defer p.m.Unlock() - - rp, ok := p.pool[slot] - if !ok { - return errPoolNotFound + return nil, err } - - rp.Close() - delete(p.pool, slot) - - return nil + return &pkcs11Session{c.ctx, session}, nil } diff --git a/symmetric.go b/symmetric.go index 668f1d4..c1774ee 100644 --- a/symmetric.go +++ b/symmetric.go @@ -22,6 +22,8 @@ package crypto11 import ( + "errors" + "github.com/miekg/pkcs11" ) @@ -64,7 +66,7 @@ type SymmetricCipher struct { // CipherAES describes the AES cipher. Use this with the // GenerateSecretKey... functions. -var CipherAES = SymmetricCipher{ +var CipherAES = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_AES, @@ -82,7 +84,7 @@ var CipherAES = SymmetricCipher{ // CipherDES3 describes the three-key triple-DES cipher. Use this with the // GenerateSecretKey... functions. -var CipherDES3 = SymmetricCipher{ +var CipherDES3 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_DES3, @@ -104,7 +106,7 @@ var CipherDES3 = SymmetricCipher{ // The spec promises that this mechanism can be used to perform HMAC // operations, although implementations vary; // CipherHMACSHA1 and so on may give better results. -var CipherGeneric = SymmetricCipher{ +var CipherGeneric = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_GENERIC_SECRET, @@ -121,7 +123,7 @@ var CipherGeneric = SymmetricCipher{ // CipherHMACSHA1 describes the CKK_SHA_1_HMAC key type. Use this with the // GenerateSecretKey... functions. -var CipherHMACSHA1 = SymmetricCipher{ +var CipherHMACSHA1 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_SHA_1_HMAC, @@ -142,7 +144,7 @@ var CipherHMACSHA1 = SymmetricCipher{ // CipherHMACSHA224 describes the CKK_SHA224_HMAC key type. Use this with the // GenerateSecretKey... functions. -var CipherHMACSHA224 = SymmetricCipher{ +var CipherHMACSHA224 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_SHA224_HMAC, @@ -163,7 +165,7 @@ var CipherHMACSHA224 = SymmetricCipher{ // CipherHMACSHA256 describes the CKK_SHA256_HMAC key type. Use this with the // GenerateSecretKey... functions. -var CipherHMACSHA256 = SymmetricCipher{ +var CipherHMACSHA256 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_SHA256_HMAC, @@ -184,7 +186,7 @@ var CipherHMACSHA256 = SymmetricCipher{ // CipherHMACSHA384 describes the CKK_SHA384_HMAC key type. Use this with the // GenerateSecretKey... functions. -var CipherHMACSHA384 = SymmetricCipher{ +var CipherHMACSHA384 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_SHA384_HMAC, @@ -205,7 +207,7 @@ var CipherHMACSHA384 = SymmetricCipher{ // CipherHMACSHA512 describes the CKK_SHA512_HMAC key type. Use this with the // GenerateSecretKey... functions. -var CipherHMACSHA512 = SymmetricCipher{ +var CipherHMACSHA512 = &SymmetricCipher{ GenParams: []SymmetricGenParams{ { KeyType: pkcs11.CKK_SHA512_HMAC, @@ -226,108 +228,115 @@ var CipherHMACSHA512 = SymmetricCipher{ // Ciphers is a map of PKCS#11 key types (CKK_...) to symmetric cipher information. var Ciphers = map[int]*SymmetricCipher{ - pkcs11.CKK_AES: &CipherAES, - pkcs11.CKK_DES3: &CipherDES3, - pkcs11.CKK_GENERIC_SECRET: &CipherGeneric, - pkcs11.CKK_SHA_1_HMAC: &CipherHMACSHA1, - pkcs11.CKK_SHA224_HMAC: &CipherHMACSHA224, - pkcs11.CKK_SHA256_HMAC: &CipherHMACSHA256, - pkcs11.CKK_SHA384_HMAC: &CipherHMACSHA384, - pkcs11.CKK_SHA512_HMAC: &CipherHMACSHA512, + pkcs11.CKK_AES: CipherAES, + pkcs11.CKK_DES3: CipherDES3, + pkcs11.CKK_GENERIC_SECRET: CipherGeneric, + pkcs11.CKK_SHA_1_HMAC: CipherHMACSHA1, + pkcs11.CKK_SHA224_HMAC: CipherHMACSHA224, + pkcs11.CKK_SHA256_HMAC: CipherHMACSHA256, + pkcs11.CKK_SHA384_HMAC: CipherHMACSHA384, + pkcs11.CKK_SHA512_HMAC: CipherHMACSHA512, } -// PKCS11SecretKey contains a reference to a loaded PKCS#11 symmetric key object. +// SecretKey contains a reference to a loaded PKCS#11 symmetric key object. // -// A *PKCS11SecretKey implements the cipher.Block interface, allowing it be used +// A *SecretKey implements the cipher.Block interface, allowing it be used // as the argument to cipher.NewCBCEncrypter and similar methods. // For bulk operation this is very inefficient; // using NewCBCEncrypterCloser, NewCBCEncrypter or NewCBC from this package is // much faster. -type PKCS11SecretKey struct { - PKCS11Object +type SecretKey struct { + pkcs11Object // Symmetric cipher information Cipher *SymmetricCipher } -// Key generation ------------------------------------------------------------- - -// GenerateSecretKey creates an secret key of given length and type. -// -// The key will have a random label and ID. -func GenerateSecretKey(bits int, cipher *SymmetricCipher) (*PKCS11SecretKey, error) { - return GenerateSecretKeyOnSlot(instance.slot, nil, nil, bits, cipher) -} +// GenerateSecretKey creates an secret key of given length and type. The id parameter is used to +// set CKA_ID and must be non-nil. +func (c *Context) GenerateSecretKey(id []byte, bits int, cipher *SymmetricCipher) (*SecretKey, error) { + if c.closed.Get() { + return nil, errClosed + } -// GenerateSecretKeyOnSlot creates as symmetric key on a specified slot -// -// Either or both label and/or id can be nil, in which case random values will be generated. -func GenerateSecretKeyOnSlot(slot uint, id []byte, label []byte, bits int, cipher *SymmetricCipher) (*PKCS11SecretKey, error) { - var k *PKCS11SecretKey - var err error - if err = ensureSessions(instance, slot); err != nil { + if err := notNilBytes(id, "id"); err != nil { return nil, err } - err = withSession(slot, func(session *PKCS11Session) error { - k, err = GenerateSecretKeyOnSession(session, slot, id, label, bits, cipher) - return err - }) - return k, err + + return c.generateSecretKey(id, nil, bits, cipher) } -// GenerateSecretKeyOnSession creates a symmetric key of given type and -// length, on a specified session. -// -// Either or both label and/or id can be nil, in which case random values will be generated. -func GenerateSecretKeyOnSession(session *PKCS11Session, slot uint, id []byte, label []byte, bits int, cipher *SymmetricCipher) (key *PKCS11SecretKey, err error) { - // TODO refactor with the other key generation implementations - if label == nil { - if label, err = generateKeyLabel(); err != nil { - return nil, err - } - } - if id == nil { - if id, err = generateKeyLabel(); err != nil { - return nil, err - } +// GenerateSecretKey creates an secret key of given length and type. The id and label parameters are used to +// set CKA_ID and CKA_LABEL respectively and must be non-nil. +func (c *Context) GenerateSecretKeyWithLabel(id, label []byte, bits int, cipher *SymmetricCipher) (*SecretKey, error) { + if c.closed.Get() { + return nil, errClosed } - var privHandle pkcs11.ObjectHandle - // CKK_*_HMAC exists but there is no specific corresponding CKM_*_KEY_GEN - // mechanism. Therefore we attempt both CKM_GENERIC_SECRET_KEY_GEN and - // vendor-specific mechanisms. - for _, genMech := range cipher.GenParams { - secretKeyTemplate := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, genMech.KeyType), - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_SIGN, cipher.MAC), - pkcs11.NewAttribute(pkcs11.CKA_VERIFY, cipher.MAC), - pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, cipher.Encrypt), - pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, cipher.Encrypt), - pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), - pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), - pkcs11.NewAttribute(pkcs11.CKA_ID, id), - } - if bits > 0 { - secretKeyTemplate = append(secretKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, bits/8)) - } - mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(genMech.GenMech, nil)} - privHandle, err = session.Ctx.GenerateKey(session.Handle, mech, secretKeyTemplate) - if err == nil { - break - } - // nShield returns this if if doesn't like the CKK/CKM combination. - if e, ok := err.(pkcs11.Error); ok && e == pkcs11.CKR_TEMPLATE_INCONSISTENT { - continue - } - if err != nil { - return - } + + if err := notNilBytes(id, "id"); err != nil { + return nil, err } - if err != nil { - return + if err := notNilBytes(label, "label"); err != nil { + return nil, err } - key = &PKCS11SecretKey{PKCS11Object{privHandle, slot}, cipher} + + return c.generateSecretKey(id, label, bits, cipher) + +} + +// generateSecretKey creates an secret key of given length and type. +func (c *Context) generateSecretKey(id, label []byte, bits int, cipher *SymmetricCipher) (k *SecretKey, err error) { + err = c.withSession(func(session *pkcs11Session) error { + + // CKK_*_HMAC exists but there is no specific corresponding CKM_*_KEY_GEN + // mechanism. Therefore we attempt both CKM_GENERIC_SECRET_KEY_GEN and + // vendor-specific mechanisms. + for _, genMech := range cipher.GenParams { + secretKeyTemplate := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), + pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, genMech.KeyType), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, cipher.MAC), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, cipher.MAC), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, cipher.Encrypt), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, cipher.Encrypt), + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + } + + if id != nil { + secretKeyTemplate = append(secretKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + } + if label != nil { + secretKeyTemplate = append(secretKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + } + + if bits > 0 { + secretKeyTemplate = append(secretKeyTemplate, pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, bits/8)) + } + mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(genMech.GenMech, nil)} + + privHandle, err := session.ctx.GenerateKey(session.handle, mech, secretKeyTemplate) + if err == nil { + k = &SecretKey{pkcs11Object{privHandle, c}, cipher} + return nil + } + + // nShield returns this if if doesn't like the CKK/CKM combination. + if e, ok := err.(pkcs11.Error); ok && e == pkcs11.CKR_TEMPLATE_INCONSISTENT { + continue + } + + return err + } + + // We can only get here if there were no GenParams + return errors.New("cipher must have GenParams") + }) return } + +// Delete deletes the secret key from the token. +func (key *SecretKey) Delete() error { + return key.pkcs11Object.Delete() +} diff --git a/symmetric_test.go b/symmetric_test.go index c4682da..7f06d3b 100644 --- a/symmetric_test.go +++ b/symmetric_test.go @@ -23,120 +23,98 @@ package crypto11 import ( "bytes" - "crypto" "crypto/cipher" - "github.com/miekg/pkcs11" - "github.com/stretchr/testify/require" "runtime" "testing" + + "github.com/miekg/pkcs11" + "github.com/stretchr/testify/require" ) func TestHardSymmetric(t *testing.T) { - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") require.NoError(t, err) - t.Run("AES128", func(t *testing.T) { testHardSymmetric(t, pkcs11.CKK_AES, 128) }) - t.Run("AES192", func(t *testing.T) { testHardSymmetric(t, pkcs11.CKK_AES, 192) }) - t.Run("AES256", func(t *testing.T) { testHardSymmetric(t, pkcs11.CKK_AES, 256) }) - t.Run("DES3", func(t *testing.T) { testHardSymmetric(t, pkcs11.CKK_DES3, 0) }) - require.NoError(t, Close()) + defer func() { + require.NoError(t, ctx.Close()) + }() + + t.Run("AES128", func(t *testing.T) { testHardSymmetric(t, ctx, pkcs11.CKK_AES, 128) }) + t.Run("AES192", func(t *testing.T) { testHardSymmetric(t, ctx, pkcs11.CKK_AES, 192) }) + t.Run("AES256", func(t *testing.T) { testHardSymmetric(t, ctx, pkcs11.CKK_AES, 256) }) + t.Run("DES3", func(t *testing.T) { testHardSymmetric(t, ctx, pkcs11.CKK_DES3, 0) }) } -func testHardSymmetric(t *testing.T, keytype int, bits int) { - var err error - var key, key2 *PKCS11SecretKey - var id []byte - t.Run("Generate", func(t *testing.T) { - if key, err = GenerateSecretKey(bits, Ciphers[keytype]); err != nil { - t.Errorf("crypto11.GenerateSecretKey: %v", err) - return - } - if key == nil { - t.Errorf("crypto11.GenerateSecretKey: returned nil but no error") - return - } - if id, _, err = key.Identify(); err != nil { - t.Errorf("crypto11.PKCS11SecretKey.Identify: %v", err) - return - } - }) - var key2gen crypto.PrivateKey +func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { + + id := randomBytes() + key, err := ctx.GenerateSecretKey(id, bits, Ciphers[keytype]) + require.NoError(t, err) + require.NotNil(t, key) + + var key2 *SecretKey t.Run("Find", func(t *testing.T) { - if key2gen, err = FindKey(id, nil); err != nil { - t.Errorf("crypto11.FindKey by id: %v", err) - return - } - key2 = key2gen.(*PKCS11SecretKey) + key2, err = ctx.FindKey(id, nil) + require.NoError(t, err) }) + t.Run("Block", func(t *testing.T) { testSymmetricBlock(t, key, key2) }) + iv := make([]byte, key.BlockSize()) - for i := 0; i < len(iv); i++ { + for i := range iv { iv[i] = 0xF0 } + t.Run("CBC", func(t *testing.T) { testSymmetricMode(t, cipher.NewCBCEncrypter(key2, iv), cipher.NewCBCDecrypter(key2, iv)) }) + t.Run("CBCClose", func(t *testing.T) { - var enc, dec BlockModeCloser - if enc, err = key2.NewCBCEncrypterCloser(iv); err != nil { - t.Errorf("NewCBCEncrypter: %v", err) - return - } - if dec, err = key2.NewCBCDecrypterCloser(iv); err != nil { - t.Errorf("NewCBCDecrypter: %v", err) - return - } + + enc, err := key2.NewCBCEncrypterCloser(iv) + require.NoError(t, err) + + dec, err := key2.NewCBCDecrypterCloser(iv) + require.NoError(t, err) + testSymmetricMode(t, enc, dec) enc.Close() dec.Close() }) + t.Run("CBCNoClose", func(t *testing.T) { - var enc, dec cipher.BlockMode - if enc, err = key2.NewCBCEncrypter(iv); err != nil { - t.Errorf("NewCBCEncrypter: %v", err) - return - } - if dec, err = key2.NewCBCDecrypter(iv); err != nil { - t.Errorf("NewCBCDecrypter: %v", err) - return - } + enc, err := key2.NewCBCEncrypter(iv) + require.NoError(t, err) + + dec, err := key2.NewCBCDecrypter(iv) + require.NoError(t, err) testSymmetricMode(t, enc, dec) // See discussion at BlockModeCloser. runtime.GC() }) + t.Run("CBCSealOpen", func(t *testing.T) { aead, err := key2.NewCBC(PaddingNone) - if err != nil { - t.Errorf("cipher.NewCBC: %v", err) - return - } + require.NoError(t, err) testAEADMode(t, aead, 128, 0) }) + t.Run("CBCPKCSSealOpen", func(t *testing.T) { aead, err := key2.NewCBC(PaddingPKCS) - if err != nil { - t.Errorf("cipher.NewCBC: %v", err) - return - } + require.NoError(t, err) testAEADMode(t, aead, 127, 0) }) if bits == 128 { t.Run("GCMSoft", func(t *testing.T) { aead, err := cipher.NewGCM(key2) - if err != nil { - t.Errorf("cipher.NewGCM: %v", err) - return - } + require.NoError(t, err) testAEADMode(t, aead, 127, 129) }) t.Run("GCMHard", func(t *testing.T) { aead, err := key2.NewGCM() - if err != nil { - t.Errorf("key2.NewGCM: %v", err) - return - } - needMechanism(t, key2.Slot, pkcs11.CKM_AES_GCM) + require.NoError(t, err) + skipIfMechUnsupported(t, key2.context, pkcs11.CKM_AES_GCM) testAEADMode(t, aead, 127, 129) }) // TODO check that hard/soft is consistent! @@ -159,7 +137,7 @@ func testSymmetricBlock(t *testing.T, encryptKey cipher.Block, decryptKey cipher output[i] = byte(i + 6*b) } encryptKey.Encrypt(middle, input) // middle[:b] = encrypt(input[:b]) - if bytes.Compare(input[:b], middle[:b]) == 0 { + if bytes.Equal(input[:b], middle[:b]) { t.Errorf("crypto11.PKCSSecretKey.Encrypt: identity transformation") return } @@ -184,7 +162,7 @@ func testSymmetricBlock(t *testing.T, encryptKey cipher.Block, decryptKey cipher } } decryptKey.Decrypt(output, middle) // output[:b] = decrypt(middle[:b]) - if bytes.Compare(input[:b], output[:b]) != 0 { + if !bytes.Equal(input[:b], output[:b]) { t.Errorf("crypto11.PKCSSecretKey.Decrypt: plaintext wrong") return } @@ -208,7 +186,7 @@ func testSymmetricMode(t *testing.T, encrypt cipher.BlockMode, decrypt cipher.Bl } // Encrypt the first 128 bytes encrypt.CryptBlocks(middle, input[:128]) - if bytes.Compare(input[:128], middle[:128]) == 0 { + if bytes.Equal(input[:128], middle[:128]) { t.Errorf("BlockMode.Encrypt: did not modify destination") return } @@ -228,7 +206,7 @@ func testSymmetricMode(t *testing.T, encrypt cipher.BlockMode, decrypt cipher.Bl encrypt.CryptBlocks(middle[128:], input[128:]) // Decrypt in a single go decrypt.CryptBlocks(output, middle) - if bytes.Compare(input, output) != 0 { + if !bytes.Equal(input, output) { t.Errorf("BlockMode.Decrypt: plaintext wrong") return } @@ -250,30 +228,35 @@ func testAEADMode(t *testing.T, aead cipher.AEAD, ptlen int, adlen int) { t.Errorf("aead.Open: %s", err) return } - if bytes.Compare(plaintext, decrypted) != 0 { + if !bytes.Equal(plaintext, decrypted) { t.Errorf("aead.Open: mismatch") return } } func BenchmarkCBC(b *testing.B) { - _, err := ConfigureFromFile("config") + ctx, err := ConfigureFromFile("config") + require.NoError(b, err) + + defer func() { + require.NoError(b, ctx.Close()) + }() + + id := randomBytes() + key, err := ctx.GenerateSecretKey(id, 128, Ciphers[pkcs11.CKK_AES]) require.NoError(b, err) - var key *PKCS11SecretKey - if key, err = GenerateSecretKey(128, Ciphers[pkcs11.CKK_AES]); err != nil { - b.Errorf("crypto11.GenerateSecretKey: %v", err) - return - } iv := make([]byte, 16) plaintext := make([]byte, 65536) ciphertext := make([]byte, 65536) + b.Run("Native", func(b *testing.B) { for i := 0; i < b.N; i++ { mode := cipher.NewCBCEncrypter(key, iv) mode.CryptBlocks(ciphertext, plaintext) } }) + b.Run("IdiomaticClose", func(b *testing.B) { for i := 0; i < b.N; i++ { mode, err := key.NewCBCEncrypterCloser(iv) @@ -284,6 +267,7 @@ func BenchmarkCBC(b *testing.B) { mode.Close() } }) + b.Run("Idiomatic", func(b *testing.B) { for i := 0; i < b.N; i++ { mode, err := key.NewCBCEncrypter(iv) @@ -294,7 +278,26 @@ func BenchmarkCBC(b *testing.B) { } runtime.GC() }) - require.NoError(b, Close()) +} + +func TestSymmetricRequiredArgs(t *testing.T) { + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + _, err = ctx.GenerateSecretKey(nil, 128, CipherAES) + require.Error(t, err) + + val := randomBytes() + + _, err = ctx.GenerateSecretKeyWithLabel(nil, val, 128, CipherAES) + require.Error(t, err) + + _, err = ctx.GenerateSecretKeyWithLabel(val, nil, 128, CipherAES) + require.Error(t, err) } // TODO BenchmarkGCM along the same lines as above diff --git a/thread_test.go b/thread_test.go index bd01fb8..280690a 100644 --- a/thread_test.go +++ b/thread_test.go @@ -23,32 +23,39 @@ package crypto11 import ( "crypto" - "github.com/stretchr/testify/require" "testing" "time" + + "github.com/stretchr/testify/require" ) var threadCount = 32 var signaturesPerThread = 256 func TestThreadedRSA(t *testing.T) { - var key *PKCS11PrivateKeyRSA - _, err := ConfigureFromFile("config") + + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + id := randomBytes() + key, err := ctx.GenerateRSAKeyPair(id, 1024) require.NoError(t, err) - if key, err = GenerateRSAKeyPair(1024); err != nil { - t.Errorf("crypto11.GenerateRSAKeyPair: %v", err) - return - } done := make(chan int) started := time.Now() + t.Logf("Starting %v threads", threadCount) + for i := 0; i < threadCount; i++ { go signingRoutine(t, key, done) } t.Logf("Waiting for %v threads", threadCount) for i := 0; i < threadCount; i++ { - _ = <-done + <-done } finished := time.Now() ticks := finished.Sub(started) @@ -56,7 +63,6 @@ func TestThreadedRSA(t *testing.T) { t.Logf("Made %v signatures in %v elapsed (%v/s)", threadCount*signaturesPerThread, elapsed, float64(threadCount*signaturesPerThread)/elapsed) - require.NoError(t, Close()) } func signingRoutine(t *testing.T, key crypto.Signer, done chan int) {