|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import typing |
|
|
|
from cryptography.hazmat.primitives.ciphers import Cipher |
|
from cryptography.hazmat.primitives.ciphers.algorithms import AES |
|
from cryptography.hazmat.primitives.ciphers.modes import ECB |
|
from cryptography.hazmat.primitives.constant_time import bytes_eq |
|
|
|
|
|
def _wrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> bytes:
    """Run the six-pass RFC 3394 wrapping loop over 64-bit blocks.

    :param wrapping_key: AES key used (in ECB mode) as the key-encryption key.
    :param a: Initial 8-byte value of the integrity register ``A``.
    :param r: The 8-byte plaintext blocks. The list is overwritten in place
        with the intermediate/final ciphertext blocks.
    :return: ``A || R[1] || ... || R[n]`` after the last pass. With no blocks,
        no encryption happens and ``a`` is returned unchanged.
    """
    block_cipher = Cipher(AES(wrapping_key), ECB()).encryptor()
    # ``step`` is the RFC's counter t = n*j + i + 1; it runs 1 .. 6n in order.
    step = 0
    for _pass in range(6):
        for idx, block in enumerate(r):
            step += 1
            ciphertext = block_cipher.update(a + block)
            # A = MSB(64, B) XOR t
            masked = int.from_bytes(ciphertext[:8], byteorder="big") ^ step
            a = masked.to_bytes(length=8, byteorder="big")
            # R[i] = LSB(64, B)
            r[idx] = ciphertext[-8:]

    # Sanity check: every update() above was one whole block, so finalize()
    # is expected to emit nothing.
    assert block_cipher.finalize() == b""

    return b"".join([a, *r])
|
|
|
|
|
def aes_key_wrap(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` with ``wrapping_key`` using AES Key Wrap (RFC 3394).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param key_to_wrap: Key material to protect; at least 16 bytes and a
        multiple of 8 bytes.
    :param backend: Unused by this implementation.
    :return: The wrapped key, 8 bytes longer than ``key_to_wrap``.
    :raises ValueError: if either argument has an invalid length.
    """
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    key_len = len(key_to_wrap)
    if key_len < 16:
        raise ValueError("The key to wrap must be at least 16 bytes")
    if key_len % 8:
        raise ValueError("The key to wrap must be a multiple of 8 bytes")

    # RFC 3394 default initial value: eight 0xA6 octets.
    default_iv = b"\xa6" * 8
    blocks = [key_to_wrap[off : off + 8] for off in range(0, key_len, 8)]
    return _wrap_core(wrapping_key, default_iv, blocks)
|
|
|
|
|
def _unwrap_core(
    wrapping_key: bytes,
    a: bytes,
    r: list[bytes],
) -> tuple[bytes, list[bytes]]:
    """Run the RFC 3394 unwrapping loop, the inverse of ``_wrap_core``.

    :param wrapping_key: AES key used (in ECB mode) as the key-encryption key.
    :param a: The first 8 bytes of the wrapped key (encrypted ``A`` register).
    :param r: The remaining 8-byte ciphertext blocks. The list is overwritten
        in place with the recovered plaintext blocks.
    :return: ``(A, r)``: the recovered integrity register and the same list
        object ``r``, now holding plaintext blocks. The caller must validate
        ``A``.
    """
    block_cipher = Cipher(AES(wrapping_key), ECB()).decryptor()
    block_count = len(r)
    # Counter t = n*j + i + 1, walked backwards from 6n down to 1.
    step = 6 * block_count
    for _pass in range(6):
        for idx in range(block_count - 1, -1, -1):
            masked = int.from_bytes(a, byteorder="big") ^ step
            plaintext = block_cipher.update(
                masked.to_bytes(length=8, byteorder="big") + r[idx]
            )
            # A = MSB(64, B); R[i] = LSB(64, B)
            a, r[idx] = plaintext[:8], plaintext[-8:]
            step -= 1

    # Sanity check: every update() above was one whole block, so finalize()
    # is expected to emit nothing.
    assert block_cipher.finalize() == b""
    return a, r
|
|
|
|
|
def aes_key_wrap_with_padding(
    wrapping_key: bytes,
    key_to_wrap: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Wrap ``key_to_wrap`` using AES Key Wrap with Padding (RFC 5649).

    Unlike ``aes_key_wrap``, the key may be any non-empty length. It is
    zero-padded to a multiple of 8 bytes, and its true length is recorded in
    the alternative IV so that unwrapping can strip the padding again.

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param key_to_wrap: Key material to protect; must be non-empty.
    :param backend: Unused by this implementation.
    :return: The wrapped key (at least 16 bytes, a multiple of 8).
    :raises ValueError: if ``wrapping_key`` is not a valid AES key length, or
        if ``key_to_wrap`` is empty.
    """
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    # RFC 5649 needs at least one octet of key data (the unwrap side requires
    # MLI > 8*(n-1) >= 0). Previously an empty key reached _wrap_core with
    # zero blocks. That path performs no encryption, so it returned the
    # 8-byte alternative IV in the clear. aes_key_unwrap_with_padding always
    # rejects that output.
    if not key_to_wrap:
        raise ValueError("The key to wrap must be at least 1 byte")

    # Alternative IV: the constant A65959A6 followed by the 32-bit big-endian
    # Message Length Indicator (the unpadded key length in bytes).
    aiv = b"\xa6\x59\x59\xa6" + len(key_to_wrap).to_bytes(
        length=4, byteorder="big"
    )

    # Zero-pad up to the next multiple of 8 bytes (no padding if aligned).
    pad = (8 - (len(key_to_wrap) % 8)) % 8
    padded = key_to_wrap + b"\x00" * pad

    if len(padded) == 8:
        # RFC 5649 section 4.1: a single 64-bit block after padding is
        # encrypted directly as AIV || P with one AES (ECB) operation.
        encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
        b = encryptor.update(aiv + padded)
        assert encryptor.finalize() == b""
        return b

    r = [padded[i : i + 8] for i in range(0, len(padded), 8)]
    return _wrap_core(wrapping_key, aiv, r)
|
|
|
|
|
def aes_key_unwrap_with_padding(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Unwrap a key produced by ``aes_key_wrap_with_padding`` (RFC 5649).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param wrapped_key: The wrapped key; at least 16 bytes and a multiple of 8.
    :param backend: Unused by this implementation.
    :return: The original key bytes with the zero padding removed.
    :raises InvalidUnwrap: if ``wrapped_key`` has an invalid length, or if the
        alternative-IV constant, the length indicator or the zero padding
        fails to verify after decryption.
    :raises ValueError: if ``wrapping_key`` is not a valid AES key length.
    """
    if len(wrapped_key) < 16:
        raise InvalidUnwrap("Must be at least 16 bytes")

    # Wrapped output is always made of whole 64-bit blocks (see aes_key_unwrap,
    # which performs the same check). Without this check a ragged input was
    # chunked into a short final block and fed to the ECB decryptor. That
    # misaligned the decryptor's input, so the failure came from the cipher
    # layer instead of being reported as InvalidUnwrap.
    if len(wrapped_key) % 8 != 0:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")

    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    if len(wrapped_key) == 16:
        # RFC 5649 section 4.2: exactly two 64-bit blocks means the key was
        # wrapped with a single direct AES (ECB) encryption, so undo it with
        # one decryption rather than the six-pass algorithm.
        decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
        out = decryptor.update(wrapped_key)
        assert decryptor.finalize() == b""
        a = out[:8]
        data = out[8:]
        n = 1
    else:
        r = [wrapped_key[i : i + 8] for i in range(0, len(wrapped_key), 8)]
        encrypted_aiv = r.pop(0)
        n = len(r)
        a, r = _unwrap_core(wrapping_key, encrypted_aiv, r)
        data = b"".join(r)

    # Validate the recovered alternative IV:
    #   1) MSB(32, A) must equal A65959A6;
    #   2) MLI = LSB(32, A) must satisfy 8*(n-1) < MLI <= 8*n;
    #   3) the b = 8*n - MLI rightmost octets of the data must be zero.
    # The `or` chain short-circuits, so data[-b:] is only inspected after the
    # MLI range check has bounded b to 0..7.
    mli = int.from_bytes(a[4:], byteorder="big")
    b = (8 * n) - mli
    if (
        not bytes_eq(a[:4], b"\xa6\x59\x59\xa6")
        or not 8 * (n - 1) < mli <= 8 * n
        or (b != 0 and not bytes_eq(data[-b:], b"\x00" * b))
    ):
        raise InvalidUnwrap()

    # b == 0 needs its own branch: data[:-0] would be empty, not the full data.
    if b == 0:
        return data
    return data[:-b]
|
|
|
|
|
def aes_key_unwrap(
    wrapping_key: bytes,
    wrapped_key: bytes,
    backend: typing.Any = None,
) -> bytes:
    """Unwrap a key produced by ``aes_key_wrap`` (RFC 3394).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param wrapped_key: The wrapped key; at least 24 bytes and a multiple of 8.
    :param backend: Unused by this implementation.
    :return: The unwrapped key, 8 bytes shorter than ``wrapped_key``.
    :raises InvalidUnwrap: if ``wrapped_key`` has an invalid length or the
        recovered integrity value is not the RFC 3394 default IV.
    :raises ValueError: if ``wrapping_key`` is not a valid AES key length.
    """
    total = len(wrapped_key)
    if total < 24:
        raise InvalidUnwrap("Must be at least 24 bytes")
    if total % 8:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    # The first 64-bit block is the encrypted integrity register A; the rest
    # are the ciphertext blocks R[1..n].
    head, tail = wrapped_key[:8], wrapped_key[8:]
    blocks = [tail[off : off + 8] for off in range(0, len(tail), 8)]
    check_value, blocks = _unwrap_core(wrapping_key, head, blocks)

    # Constant-time comparison against the RFC 3394 default IV (8 x 0xA6).
    if not bytes_eq(check_value, b"\xa6" * 8):
        raise InvalidUnwrap()

    return b"".join(blocks)
|
|
|
|
|
class InvalidUnwrap(Exception):
    """Raised by the AES key-unwrap functions on malformed or unauthentic input.

    This covers wrapped keys with an invalid length and wrapped keys whose
    recovered integrity/IV value does not verify after decryption.
    """
|
|