428 lines
13 KiB
Python
428 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright 2017 Gehirn Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import hmac
|
|
from warnings import warn
|
|
from abc import (
|
|
ABC,
|
|
abstractmethod,
|
|
)
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Mapping,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
Optional
|
|
)
|
|
from functools import wraps
|
|
|
|
import cryptography.hazmat.primitives.serialization as serialization_module
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
|
rsa_crt_dmp1,
|
|
rsa_crt_dmq1,
|
|
rsa_crt_iqmp,
|
|
rsa_recover_prime_factors,
|
|
RSAPrivateKey,
|
|
RSAPrivateNumbers,
|
|
RSAPublicKey,
|
|
RSAPublicNumbers,
|
|
)
|
|
from cryptography.hazmat.primitives.hashes import HashAlgorithm
|
|
|
|
from .exceptions import (
|
|
MalformedJWKError,
|
|
UnsupportedKeyTypeError,
|
|
)
|
|
from .utils import (
|
|
b64encode,
|
|
b64decode,
|
|
uint_b64encode,
|
|
uint_b64decode,
|
|
)
|
|
|
|
# Unconstrained generic type variable.
_T = TypeVar("_T")
# Type variable for alternate constructors (``from_dict``) that return an
# instance of whichever AbstractJWKBase subclass they are called on.
_AJWK = TypeVar("_AJWK", bound="AbstractJWKBase")
|
|
|
|
|
|
class AbstractJWKBase(ABC):
    """Abstract interface implemented by every JSON Web Key class.

    Subclasses wrap one key type and provide signing, verification and
    conversion to and from the JWK dict representation.
    """

    @abstractmethod
    def get_kty(self) -> str:
        """Return the JWK ``kty`` (key type) value of this key."""

    @abstractmethod
    def get_kid(self) -> str:
        """Return the key ID (``kid``) of this key, if the subclass has one."""

    @abstractmethod
    def is_sign_key(self) -> bool:
        """Return True when this key can be used to create signatures."""

    @abstractmethod
    def sign(self, message: bytes, **options) -> bytes:
        """Sign *message* and return the raw signature bytes.

        *options* carries algorithm-specific parameters chosen by the caller.
        """

    @abstractmethod
    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return True if *signature* is valid for *message*, else False.

        *options* carries algorithm-specific parameters chosen by the caller.
        """

    @abstractmethod
    def to_dict(self, public_only: bool = True) -> Dict[str, str]:
        """Serialize this key to its JWK dict form.

        Subclasses may omit private material when *public_only* is true.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls: Type[_AJWK], dct: Dict[str, object]) -> _AJWK:
        """Construct an instance of *cls* from a JWK dict."""
|
|
|
|
|
|
class OctetJWK(AbstractJWKBase):
    """Symmetric (``kty: oct``) JSON Web Key wrapping a raw secret.

    https://tools.ietf.org/html/rfc7518.html#section-6.4
    """

    def __init__(self, key: bytes, kid: Optional[str] = None,
                 **options) -> None:
        super().__init__()
        # Raw secret bytes handed to the caller-supplied signer.
        self.key = key
        self.kid = kid

        # Keep only recognised JWK parameters; everything else (e.g. the
        # ``kty``/``k`` members forwarded by ``from_dict``) is dropped.
        optnames = {'use', 'key_ops', 'alg', 'x5u', 'x5c', 'x5t', 'x5t#s256'}
        self.options = {k: v for k, v in options.items() if k in optnames}

    def get_kty(self):
        """Return the JWK key type, always ``'oct'``."""
        return 'oct'

    def get_kid(self):
        """Return the key ID given at construction (may be None)."""
        return self.kid

    def is_sign_key(self) -> bool:
        """Symmetric keys can always sign."""
        return True

    def _get_signer(self, options) -> Callable[[bytes, bytes], bytes]:
        """Return ``options['signer']``; raises KeyError when it is absent.

        The signer is called as ``signer(message, key)`` and returns the MAC.
        """
        return options['signer']

    def sign(self, message: bytes, **options) -> bytes:
        """Return ``signer(message, key)`` using ``options['signer']``."""
        signer = self._get_signer(options)
        return signer(message, self.key)

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Recompute the MAC for *message* and compare it to *signature*.

        ``hmac.compare_digest`` is used so the comparison does not leak
        timing information about where the values differ.
        """
        signer = self._get_signer(options)
        return hmac.compare_digest(signature, signer(message, self.key))

    def to_dict(self, public_only=True):
        """Serialize to a JWK dict.

        *public_only* is accepted for interface compatibility but ignored:
        the secret ``k`` is always included.
        """
        dct = {
            'kty': 'oct',
            'k': b64encode(self.key),
        }
        dct.update(self.options)
        if self.kid:
            dct['kid'] = self.kid
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an OctetJWK from a JWK dict.

        Raises MalformedJWKError when the required ``k`` member is missing.
        Members the class does not understand are ignored, as required by
        RFC 7517 section 4.
        """
        try:
            encoded_key = dct['k']
        except KeyError as why:
            raise MalformedJWKError('k is required') from why
        # A JWK member literally named 'key' would otherwise collide with
        # the constructor's positional ``key`` argument and raise TypeError.
        options = {
            name: value for name, value in dct.items() if name != 'key'
        }
        return cls(b64decode(encoded_key), **options)
|
|
|
|
|
|
class RSAJWK(AbstractJWKBase):
    """RSA JSON Web Key wrapping a ``cryptography`` public or private key.

    https://tools.ietf.org/html/rfc7518.html#section-6.3.1
    """

    def __init__(self, keyobj: Union[RSAPrivateKey, RSAPublicKey],
                 **options) -> None:
        super().__init__()
        # Private keys can sign and verify; public keys can only verify.
        self.keyobj = keyobj

        # Keep only recognised JWK parameters; key material members such as
        # n/e/d forwarded by ``from_dict`` are dropped here.
        optnames = {'use', 'key_ops', 'alg', 'kid',
                    'x5u', 'x5c', 'x5t', 'x5t#s256', }
        self.options = {k: v for k, v in options.items() if k in optnames}

    def is_sign_key(self) -> bool:
        """Return True when the wrapped key is a private key."""
        return isinstance(self.keyobj, RSAPrivateKey)

    def _get_hash_fun(self, options) -> Callable[[], HashAlgorithm]:
        """Return the hash factory in ``options['hash_fun']``.

        Raises KeyError when it is missing.
        """
        return options['hash_fun']

    def _get_padding(self, options) -> padding.AsymmetricPadding:
        """Return ``options['padding']``, falling back to PKCS#1 v1.5.

        The fallback emits a warning; per its text, callers are expected to
        go through the jwa layer rather than call sign/verify directly.
        """
        try:
            return options['padding']
        except KeyError:
            warn('you should not use RSAJWK.verify/sign without jwa '
                 'intermiediary, used legacy padding')
            return padding.PKCS1v15()

    def sign(self, message: bytes, **options) -> bytes:
        """Sign *message* with the wrapped private key.

        Raises ValueError when only a public key is available.
        """
        if isinstance(self.keyobj, RSAPublicKey):
            raise ValueError("Requires a private key.")
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        return self.keyobj.sign(message, _padding, hash_fun())

    def verify(self, message: bytes, signature: bytes, **options) -> bool:
        """Return True if *signature* is valid for *message*, else False.

        A private key is verified through its derived public key.
        """
        hash_fun = self._get_hash_fun(options)
        _padding = self._get_padding(options)
        if isinstance(self.keyobj, RSAPrivateKey):
            pubkey = self.keyobj.public_key()
        else:
            pubkey = self.keyobj
        try:
            pubkey.verify(signature, message, _padding, hash_fun())
            return True
        except InvalidSignature:
            return False

    def get_kty(self):
        """Return the JWK key type, always ``'RSA'``."""
        return 'RSA'

    def get_kid(self):
        """Return the ``kid`` option, or None when it was not given."""
        return self.options.get('kid')

    def to_dict(self, public_only=True):
        """Serialize this key to a JWK dict.

        ``e`` and ``n`` are always included. When the wrapped key is private
        and *public_only* is false, ``d``, ``p``, ``q``, ``dp``, ``dq`` and
        ``qi`` are added as well.
        """
        dct = {
            'kty': 'RSA',
        }
        dct.update(self.options)

        priv_numbers: Optional[RSAPrivateNumbers] = None
        if isinstance(self.keyobj, RSAPrivateKey):
            priv_numbers = self.keyobj.private_numbers()
            pub_numbers = priv_numbers.public_numbers
        else:
            pub_numbers = self.keyobj.public_numbers()

        dct['e'] = uint_b64encode(pub_numbers.e)
        dct['n'] = uint_b64encode(pub_numbers.n)

        if priv_numbers is not None and not public_only:
            dct.update({
                'd': uint_b64encode(priv_numbers.d),
                'p': uint_b64encode(priv_numbers.p),
                'q': uint_b64encode(priv_numbers.q),
                'dp': uint_b64encode(priv_numbers.dmp1),
                'dq': uint_b64encode(priv_numbers.dmq1),
                'qi': uint_b64encode(priv_numbers.iqmp),
            })
        return dct

    @classmethod
    def from_dict(cls, dct):
        """Build an RSAJWK from a JWK dict (RFC 7518 section 6.3).

        Without ``d`` a public key is built. With ``d``, either all of
        ``p``, ``q``, ``dp``, ``dq``, ``qi`` must be present, or none of
        them, in which case the primes are recovered from ``n``, ``e``, ``d``.

        Raises UnsupportedKeyTypeError for multi-prime keys (``oth``) and
        MalformedJWKError when required members are missing.
        """
        if 'oth' in dct:
            raise UnsupportedKeyTypeError(
                'RSA keys with multiples primes are not supported')

        try:
            e = uint_b64decode(dct['e'])
            n = uint_b64decode(dct['n'])
        except KeyError as why:
            raise MalformedJWKError('e and n are required') from why
        pub_numbers = RSAPublicNumbers(e, n)

        # Unknown JWK members must be ignored (RFC 7517 section 4); a member
        # named 'keyobj' would otherwise collide with the constructor's
        # positional argument and raise TypeError.
        options = {
            name: value for name, value in dct.items() if name != 'keyobj'
        }

        if 'd' not in dct:
            return cls(
                pub_numbers.public_key(backend=default_backend()), **options)
        d = uint_b64decode(dct['d'])

        privparams = {'p', 'q', 'dp', 'dq', 'qi'}
        present = set(dct.keys()) & privparams
        if not present:
            p, q = rsa_recover_prime_factors(n, e, d)
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=pub_numbers)
        elif present == privparams:
            priv_numbers = RSAPrivateNumbers(
                d=d,
                p=uint_b64decode(dct['p']),
                q=uint_b64decode(dct['q']),
                dmp1=uint_b64decode(dct['dp']),
                dmq1=uint_b64decode(dct['dq']),
                iqmp=uint_b64decode(dct['qi']),
                public_numbers=pub_numbers)
        else:
            # If the producer includes any of the other private key parameters,
            # then all of the others MUST be present, with the exception of
            # "oth", which MUST only be present when more than two prime
            # factors were used.
            raise MalformedJWKError(
                'p, q, dp, dq, qi MUST be present or '
                'all of them MUST be absent')
        privkey = priv_numbers.private_key(backend=default_backend())
        return cls(privkey, **options)
|
|
|
|
|
|
def supported_key_types() -> Dict[str, Type[AbstractJWKBase]]:
    """Return a new mapping of JWK ``kty`` values to their JWK classes."""
    return dict(oct=OctetJWK, RSA=RSAJWK)
|
|
|
|
|
|
def jwk_from_dict(dct: Mapping[str, Any]) -> AbstractJWKBase:
    """Create a JWK object from its dict (parsed JSON) representation.

    The ``kty`` member selects the class from ``supported_key_types()``,
    whose ``from_dict`` does the actual construction.

    Raises TypeError when *dct* is not a dict, MalformedJWKError when
    ``kty`` is missing, and UnsupportedKeyTypeError for unknown key types.
    """
    if not isinstance(dct, dict):  # pragma: no cover
        raise TypeError('dct must be a dict')
    if 'kty' not in dct:
        raise MalformedJWKError('kty MUST be present')

    kty = dct['kty']
    jwk_class = supported_key_types().get(kty)
    if jwk_class is None:
        raise UnsupportedKeyTypeError('unsupported key type: {}'.format(kty))
    return jwk_class.from_dict(dct)
|
|
|
|
|
|
# A loader is either the name of a function in
# ``cryptography.hazmat.primitives.serialization`` or the callable itself.
# Public loaders are invoked as ``loader(content, backend)``; private loaders
# as ``loader(content, password, backend)``.
PublicKeyLoaderT = Union[
    str,
    Callable[[bytes, object], object],
]
PrivateKeyLoaderT = Union[
    str,
    Callable[[bytes, Optional[str], object], object],
]
_Loader = TypeVar("_Loader", PublicKeyLoaderT, PrivateKeyLoaderT)
# Any callable; used so decorators can declare "returns the same type".
_C = TypeVar("_C", bound=Callable[..., Any])
|
|
|
|
|
|
# NOTE: ideally the LoaderT aliases above would not be Unions and this
# decorator would be typed roughly as
#     (func: Callable[[bytes, _Loader], _T])
#         -> Callable[[bytes, Union[str, _Loader]], _T]
# but that drops the keyword arguments from the typing information.
# Probably needs: https://github.com/python/mypy/issues/3157
def jwk_from_bytes_argument_conversion(func: _C) -> _C:
    """Normalize the arguments of a ``jwk_from_*_bytes`` function.

    The returned wrapper turns a string *loader* into the function of that
    name from ``cryptography.hazmat.primitives.serialization``, and replaces
    a missing or ``None`` ``options`` keyword with an empty dict, before
    calling *func*.

    Raises a plain Exception at decoration time when *func*'s name contains
    neither "private" nor "public".
    """
    func_name = func.__name__
    if 'private' not in func_name and 'public' not in func_name:
        raise Exception("the wrapped function must have either public"
                        " or private in it's name")

    @wraps(func)
    def wrapper(content, loader, **kwargs):
        if isinstance(loader, str):
            # Resolve e.g. 'load_pem_public_key' to the actual function.
            loader = getattr(serialization_module, loader)
        if kwargs.get('options') is None:
            kwargs['options'] = {}
        return func(content, loader, **kwargs)

    return wrapper  # type: ignore[return-value]
|
|
|
|
|
|
@jwk_from_bytes_argument_conversion
def jwk_from_private_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    *,
    password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load a private key from *content* and wrap it in an RSAJWK.

    Intended to be called from ``jwk_from_bytes``. Raises
    UnsupportedKeyTypeError when the loaded key is not an RSA private key,
    or when the loader raises ValueError (e.g. *content* is a public key).
    """
    options = {} if options is None else options
    try:
        loaded = private_loader(content, password, backend)  # type: ignore[operator] # noqa: E501
        if not isinstance(loaded, RSAPrivateKey):
            raise UnsupportedKeyTypeError('unsupported key type')
        return RSAJWK(loaded, **options)
    except ValueError as ex:
        # jwk_from_bytes reacts to UnsupportedKeyTypeError by retrying with
        # the public-key loader.
        raise UnsupportedKeyTypeError('this is probably a public key') from ex
|
|
|
|
|
|
@jwk_from_bytes_argument_conversion
def jwk_from_public_bytes(
    content: bytes,
    public_loader: PublicKeyLoaderT,
    *,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None
) -> AbstractJWKBase:
    """Load a public key from *content* and wrap it in an RSAJWK.

    Intended to be called from ``jwk_from_bytes``. Raises
    UnsupportedKeyTypeError when the loaded key is not an RSA public key,
    or when the loader raises ValueError.
    """
    options = {} if options is None else options
    try:
        loaded = public_loader(content, backend)  # type: ignore[operator]
        if not isinstance(loaded, RSAPublicKey):
            raise UnsupportedKeyTypeError(
                'unsupported key type')  # pragma: no cover
        return RSAJWK(loaded, **options)
    except ValueError as why:
        raise UnsupportedKeyTypeError('could not deserialize') from why
|
|
|
|
|
|
def jwk_from_bytes(
    content: bytes,
    private_loader: PrivateKeyLoaderT,
    public_loader: PublicKeyLoaderT,
    *,
    private_password: Optional[str] = None,
    backend: Optional[object] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Load *content* as a private key, falling back to a public key.

    If ``jwk_from_private_bytes`` raises UnsupportedKeyTypeError, the same
    bytes are handed to ``jwk_from_public_bytes``, whose result (or
    exception) is propagated.
    """
    try:
        jwk = jwk_from_private_bytes(content, private_loader,
                                     password=private_password,
                                     backend=backend, options=options)
    except UnsupportedKeyTypeError:
        jwk = jwk_from_public_bytes(content, public_loader,
                                    backend=backend, options=options)
    return jwk
|
|
|
|
|
|
def jwk_from_pem(
    pem_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Create a JWK from PEM bytes, trying a private key before a public one."""
    return jwk_from_bytes(
        pem_content,
        'load_pem_private_key',
        'load_pem_public_key',
        private_password=private_password,
        backend=None,
        options=options,
    )
|
|
|
|
|
|
def jwk_from_der(
    der_content: bytes,
    private_password: Optional[str] = None,
    options: Optional[Mapping[str, object]] = None,
) -> AbstractJWKBase:
    """Create a JWK from DER bytes, trying a private key before a public one."""
    return jwk_from_bytes(
        der_content,
        'load_der_private_key',
        'load_der_public_key',
        private_password=private_password,
        backend=None,
        options=options,
    )
|