diff --git a/crypto/olm/pk.go b/crypto/olm/pk.go index 1e628745..e441ba14 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk.go @@ -109,3 +109,62 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) { func (p *PkSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } + +type PkDecryption struct { + int *C.OlmPkDecryption + mem []byte + PublicKey []byte +} + +func pkDecryptionSize() uint { + return uint(C.olm_pk_decryption_size()) +} + +func pkDecryptionPublicKeySize() uint { + return uint(C.olm_pk_key_length()) +} + +func NewPkDecryption(privateKey []byte) (*PkDecryption, error) { + memory := make([]byte, pkDecryptionSize()) + p := &PkDecryption{ + int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), + mem: memory, + } + p.Clear() + pubKey := make([]byte, pkDecryptionPublicKeySize()) + + if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { + return nil, p.lastError() + } + p.PublicKey = pubKey + + return p, nil +} + +func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) + plaintext := make([]byte, maxPlaintextLength) + + size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), + unsafe.Pointer(&mac[0]), C.size_t(len(mac)), + unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), + unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) + if size == errorVal() { + return nil, p.lastError() + } + + return plaintext[:size], nil +} + +// Clear clears the underlying memory of a PkDecryption object. +func (p *PkDecryption) Clear() { + C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) +} + +// lastError returns the last error that happened in relation to this PkDecryption object. +func (p *PkDecryption) lastError() error { + return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) +}