Skip to content

Commit

Permalink
Respect nbf and exp in local encrypt/wrap operations (#11953)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 15, 2020
1 parent 935b7ad commit 743dea5
Show file tree
Hide file tree
Showing 8 changed files with 798 additions and 45 deletions.
3 changes: 2 additions & 1 deletion sdk/keyvault/azure-keyvault-keys/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Release History

## 4.2.0b2 (Unreleased)

- `CryptographyClient` will no longer perform encrypt or wrap operations when
its key has expired or is not yet valid.

## 4.2.0b1 (2020-03-10)
- Support for Key Vault API version 7.1-preview
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from datetime import datetime, timedelta, tzinfo

import six
from azure.core.exceptions import AzureError, HttpResponseError
from azure.core.tracing.decorator import distributed_trace
Expand All @@ -24,6 +26,43 @@
from ._internal import Key as _Key


class _UTC_TZ(tzinfo):
"""from https://docs.python.org/2/library/datetime.html#tzinfo-objects"""

ZERO = timedelta(0)

def utcoffset(self, dt):
return self.ZERO

def tzname(self, dt):
return "UTC"

def dst(self, dt):
return self.ZERO


_UTC = _UTC_TZ()


def _enforce_nbf_exp(key):
# type: (KeyVaultKey) -> None
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
except AttributeError:
# we consider the key valid because a user must have deliberately created it
# (if it came from Key Vault, it would have those attributes)
return

now = datetime.now(_UTC)
if (nbf and exp) and not nbf <= now <= exp:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(nbf, exp))
if nbf and nbf >= now:
raise ValueError("This client's key is not useable until {} (UTC)".format(nbf))
if exp and exp <= now:
raise ValueError("This client's key expired at {} (UTC)".format(exp))


class CryptographyClient(KeyVaultClientBase):
"""Performs cryptographic operations using Azure Key Vault keys.
Expand Down Expand Up @@ -80,9 +119,7 @@ def __init__(self, key, credential, **kwargs):

self._internal_key = None # type: Optional[_Key]

super(CryptographyClient, self).__init__(
vault_url=self._key_id.vault_url, credential=credential, **kwargs
)
super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)

@property
def key_id(self):
Expand Down Expand Up @@ -116,7 +153,7 @@ def _get_key(self, **kwargs):
return self._key

def _get_local_key(self, **kwargs):
# type: () -> Optional[_Key]
# type: (**Any) -> Optional[_Key]
"""Gets an object implementing local operations. Will be ``None``, if the client was instantiated with a key
id and lacks keys/get permission."""

Expand All @@ -140,7 +177,6 @@ def _get_local_key(self, **kwargs):
@distributed_trace
def encrypt(self, algorithm, plaintext, **kwargs):
# type: (EncryptionAlgorithm, bytes, **Any) -> EncryptResult
# pylint:disable=line-too-long
"""Encrypt bytes using the client's key. Requires the keys/encrypt permission.
This method encrypts only a single block of data, whose size depends on the key and encryption algorithm.
Expand All @@ -166,6 +202,7 @@ def encrypt(self, algorithm, plaintext, **kwargs):

local_key = self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "encrypt" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/encrypt' permission")
result = local_key.encrypt(plaintext, algorithm=algorithm.value)
Expand Down Expand Up @@ -238,6 +275,7 @@ def wrap_key(self, algorithm, key, **kwargs):

local_key = self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "wrapKey" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/wrapKey' permission")
result = local_key.wrap_key(key, algorithm=algorithm.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import codecs
import uuid

from cryptography.exceptions import InvalidSignature
Expand Down Expand Up @@ -170,28 +169,28 @@ def default_signature_algorithm(self):

def encrypt(self, plain_text, **kwargs):
algorithm = self._get_algorithm("encrypt", **kwargs)
encryptor = algorithm.create_encryptor(self._rsa_impl)
encryptor = algorithm.create_encryptor(self.public_key)
return encryptor.transform(plain_text)

def decrypt(self, cipher_text, **kwargs):
if not self.is_private_key():
raise NotImplementedError("The current RsaKey does not support decrypt")

algorithm = self._get_algorithm("decrypt", **kwargs)
decryptor = algorithm.create_decryptor(self._rsa_impl)
decryptor = algorithm.create_decryptor(self.private_key)
return decryptor.transform(cipher_text)

def sign(self, digest, **kwargs):
if not self.is_private_key():
raise NotImplementedError("The current RsaKey does not support sign")

algorithm = self._get_algorithm("sign", **kwargs)
signer = algorithm.create_signature_transform(self._rsa_impl)
signer = algorithm.create_signature_transform(self.private_key)
return signer.sign(digest)

def verify(self, digest, signature, **kwargs):
algorithm = self._get_algorithm("verify", **kwargs)
signer = algorithm.create_signature_transform(self._rsa_impl)
signer = algorithm.create_signature_transform(self.public_key)
try:
# cryptography's verify methods return None, and raise when verification fails
signer.verify(digest, signature)
Expand All @@ -201,15 +200,15 @@ def verify(self, digest, signature, **kwargs):

def wrap_key(self, key, **kwargs):
algorithm = self._get_algorithm("wrapKey", **kwargs)
encryptor = algorithm.create_encryptor(self._rsa_impl)
encryptor = algorithm.create_encryptor(self.public_key)
return encryptor.transform(key)

def unwrap_key(self, encrypted_key, **kwargs):
if not self.is_private_key():
raise NotImplementedError("The current RsaKey does not support unwrap")

algorithm = self._get_algorithm("unwrapKey", **kwargs)
decryptor = algorithm.create_decryptor(self._rsa_impl)
decryptor = algorithm.create_decryptor(self.private_key)
return decryptor.transform(encrypted_key)

def is_private_key(self):
Expand All @@ -223,20 +222,3 @@ def _public_key_material(self):

def _private_key_material(self):
return self.private_key.private_numbers() if self.private_key else None


def _bytes_to_int(b):
return int(codecs.encode(b, "hex"), 16)


def _int_to_bytes(i):
h = hex(i)
if len(h) > 1 and h[0:2] == "0x":
h = h[2:]

# need to strip L in python 2.x
h = h.strip("L")

if len(h) % 2:
h = "0" + h
return codecs.decode(h, "hex")
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# ------------------------------------
from azure.core.exceptions import AzureError, HttpResponseError
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.keyvault.keys._shared import AsyncKeyVaultClientBase, parse_vault_id

from .. import DecryptResult, EncryptResult, SignResult, VerifyResult, UnwrapResult, WrapResult
from .._internal import EllipticCurveKey, RsaKey, SymmetricKey
from ...crypto._client import _enforce_nbf_exp
from ..._models import KeyVaultKey
from ..._shared import AsyncKeyVaultClientBase, parse_vault_id

try:
from typing import TYPE_CHECKING
Expand All @@ -18,7 +19,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Union
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
from .. import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from .._internal import Key as _Key

Expand Down Expand Up @@ -57,7 +58,7 @@ class CryptographyClient(AsyncKeyVaultClientBase):
"""

def __init__(self, key: "Union[KeyVaultKey, str]", credential: "TokenCredential", **kwargs: "Any") -> None:
def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_vault_id(key.id)
Expand All @@ -77,9 +78,7 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "TokenCredential"

self._internal_key = None # type: Optional[_Key]

super(CryptographyClient, self).__init__(
vault_url=self._key_id.vault_url, credential=credential, **kwargs
)
super().__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)

@property
def key_id(self) -> str:
Expand Down Expand Up @@ -158,6 +157,7 @@ async def encrypt(self, algorithm: "EncryptionAlgorithm", plaintext: bytes, **kw

local_key = await self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "encrypt" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/encrypt' permission")
result = local_key.encrypt(plaintext, algorithm=algorithm.value)
Expand Down Expand Up @@ -224,6 +224,7 @@ async def wrap_key(self, algorithm: "KeyWrapAlgorithm", key: bytes, **kwargs: "A

local_key = await self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "wrapKey" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/wrapKey' permission")
result = local_key.wrap_key(key, algorithm=algorithm.value)
Expand Down
Loading

0 comments on commit 743dea5

Please sign in to comment.