|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import binascii |
|
import enum |
|
import os |
|
import re |
|
import typing |
|
import warnings |
|
from base64 import encodebytes as _base64_encode |
|
from dataclasses import dataclass |
|
|
|
from cryptography import utils |
|
from cryptography.exceptions import UnsupportedAlgorithm |
|
from cryptography.hazmat.primitives import hashes |
|
from cryptography.hazmat.primitives.asymmetric import ( |
|
dsa, |
|
ec, |
|
ed25519, |
|
padding, |
|
rsa, |
|
) |
|
from cryptography.hazmat.primitives.asymmetric import utils as asym_utils |
|
from cryptography.hazmat.primitives.ciphers import ( |
|
AEADDecryptionContext, |
|
Cipher, |
|
algorithms, |
|
modes, |
|
) |
|
from cryptography.hazmat.primitives.serialization import ( |
|
Encoding, |
|
KeySerializationEncryption, |
|
NoEncryption, |
|
PrivateFormat, |
|
PublicFormat, |
|
_KeySerializationEncryption, |
|
) |
|
|
|
# bcrypt is an optional dependency: when it cannot be imported, a stub with
# the same call signature is installed so that the failure surfaces only when
# key derivation is actually requested.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(
        password: bytes,
        salt: bytes,
        desired_key_bytes: int,
        rounds: int,
        ignore_few_rounds: bool = False,
    ) -> bytes:
        """Fallback used when bcrypt is missing; always raises.

        Raises:
            UnsupportedAlgorithm: unconditionally.
        """
        raise UnsupportedAlgorithm("Need bcrypt module")
|
|
|
|
|
_SSH_ED25519 = b"ssh-ed25519" |
|
_SSH_RSA = b"ssh-rsa" |
|
_SSH_DSA = b"ssh-dss" |
|
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256" |
|
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384" |
|
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521" |
|
_CERT_SUFFIX = b"[email protected]" |
|
|
|
|
|
_SK_SSH_ED25519 = b"[email protected]" |
|
_SK_SSH_ECDSA_NISTP256 = b"[email protected]" |
|
|
|
|
|
|
|
_SSH_RSA_SHA256 = b"rsa-sha2-256" |
|
_SSH_RSA_SHA512 = b"rsa-sha2-512" |
|
|
|
# One-line public key format: "<key type> <base64 body>"; anything after the
# second token is not captured.
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded private key blob.
_SK_MAGIC = b"openssh-key-v1\0"
# PEM-style armor lines around the base64 private key data.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names stored in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16

# Captures everything between the armor lines (non-greedy, across newlines).
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16; a prefix of this sequence pads the secret
# section to the cipher block size.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))
|
|
|
|
|
@dataclass
class _SSHCipher:
    """Parameters describing one cipher usable for private key encryption."""

    # Algorithm class, instantiated with the derived key.
    alg: type[algorithms.AES]
    # Number of key bytes taken from the KDF output.
    key_len: int
    # Mode class, instantiated with the derived IV/nonce.
    mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
    # Block length the encrypted section must be a multiple of.
    block_len: int
    # Number of IV bytes taken from the KDF output (after the key bytes).
    iv_len: int
    # Authentication tag length for AEAD ciphers, None otherwise.
    tag_len: int | None
    # Whether a tag follows the encrypted section and must be verified.
    is_aead: bool
|
|
|
|
|
|
|
# Ciphers supported for encrypted private keys, keyed by the cipher name
# stored in the private key header.
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
    b"aes256-ctr": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CTR,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    b"aes256-cbc": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.CBC,
        block_len=16,
        iv_len=16,
        tag_len=None,
        is_aead=False,
    ),
    # OpenSSH's AES-GCM cipher name; the only AEAD entry, so it carries a
    # tag length and a 12-byte nonce instead of a 16-byte IV.
    b"aes256-gcm@openssh.com": _SSHCipher(
        alg=algorithms.AES,
        key_len=32,
        mode=modes.GCM,
        block_len=16,
        iv_len=12,
        tag_len=16,
        is_aead=True,
    ),
}
|
|
|
|
|
# Maps a curve's ``name`` attribute to the SSH key type identifier used for
# ECDSA keys on that curve.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}
|
|
|
|
|
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
    """Return the SSH key type identifier for a private or public key.

    Raises:
        ValueError: if the key is not an EC, RSA, DSA or Ed25519 key, or if
            an EC key uses a curve without an SSH key type.
    """
    if isinstance(key, ec.EllipticCurvePrivateKey):
        return _ecdsa_key_type(key.public_key())
    if isinstance(key, ec.EllipticCurvePublicKey):
        return _ecdsa_key_type(key)
    if isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
        return _SSH_RSA
    if isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
        return _SSH_DSA
    if isinstance(
        key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey)
    ):
        return _SSH_ED25519
    raise ValueError("Unsupported key type")
|
|
|
|
|
def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
    """Return the SSH key type identifier for an EC public key's curve.

    Raises:
        ValueError: if the curve has no entry in ``_ECDSA_KEY_TYPE``.
    """
    curve_name = public_key.curve.name
    ssh_key_type = _ECDSA_KEY_TYPE.get(curve_name)
    if ssh_key_type is None:
        raise ValueError(
            f"Unsupported curve for ssh private key: {curve_name!r}"
        )
    return ssh_key_type
|
|
|
|
|
def _ssh_pem_encode(
    data: bytes,
    prefix: bytes = _SK_START + b"\n",
    suffix: bytes = _SK_END + b"\n",
) -> bytes:
    """Base64-encode ``data`` and wrap it between ``prefix`` and ``suffix``."""
    encoded_body = _base64_encode(data)
    return b"".join((prefix, encoded_body, suffix))
|
|
|
|
|
def _check_block_size(data: bytes, block_len: int) -> None: |
|
"""Require data to be full blocks""" |
|
if not data or len(data) % block_len != 0: |
|
raise ValueError("Corrupt data: missing padding") |
|
|
|
|
|
def _check_empty(data: bytes) -> None: |
|
"""All data should have been parsed.""" |
|
if data: |
|
raise ValueError("Corrupt data: unparsed data") |
|
|
|
|
|
def _init_cipher(
    ciphername: bytes,
    password: bytes | None,
    salt: bytes,
    rounds: int,
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
    """Derive key and IV from the password and build the cipher object.

    A single bcrypt KDF output is split: the first ``key_len`` bytes are the
    key, the remainder is the IV/nonce.

    Raises:
        ValueError: if ``password`` is None or empty.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    spec = _SSH_CIPHERS[ciphername]
    key_len = spec.key_len
    material = _bcrypt_kdf(
        password, salt, key_len + spec.iv_len, rounds, True
    )
    key = material[:key_len]
    iv = material[key_len:]
    return Cipher(spec.alg(key), spec.mode(iv))
|
|
|
|
|
def _get_u32(data: memoryview) -> tuple[int, memoryview]: |
|
"""Uint32""" |
|
if len(data) < 4: |
|
raise ValueError("Invalid data") |
|
return int.from_bytes(data[:4], byteorder="big"), data[4:] |
|
|
|
|
|
def _get_u64(data: memoryview) -> tuple[int, memoryview]: |
|
"""Uint64""" |
|
if len(data) < 8: |
|
raise ValueError("Invalid data") |
|
return int.from_bytes(data[:8], byteorder="big"), data[8:] |
|
|
|
|
|
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
    """Read a byte string with a uint32 length prefix.

    Returns the string contents and the data following it.

    Raises:
        ValueError: if the prefix is truncated or announces more bytes than
            are available.
    """
    length, body = _get_u32(data)
    if length > len(body):
        raise ValueError("Invalid data")
    return body[:length], body[length:]
|
|
|
|
|
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
    """Read a length-prefixed big-endian integer.

    Raises:
        ValueError: if the high bit of the first byte is set (i.e. the value
            would be negative), or the underlying string is malformed.
    """
    raw, remainder = _get_sshstr(data)
    if len(raw) > 0 and raw[0] & 0x80:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), remainder
|
|
|
|
|
def _to_mpint(val: int) -> bytes: |
|
"""Storage format for signed bigint.""" |
|
if val < 0: |
|
raise ValueError("negative mpint not allowed") |
|
if not val: |
|
return b"" |
|
nbytes = (val.bit_length() + 8) // 8 |
|
return utils.int_to_bytes(val, nbytes) |
|
|
|
|
|
class _FragList: |
|
"""Build recursive structure without data copy.""" |
|
|
|
flist: list[bytes] |
|
|
|
def __init__(self, init: list[bytes] | None = None) -> None: |
|
self.flist = [] |
|
if init: |
|
self.flist.extend(init) |
|
|
|
def put_raw(self, val: bytes) -> None: |
|
"""Add plain bytes""" |
|
self.flist.append(val) |
|
|
|
def put_u32(self, val: int) -> None: |
|
"""Big-endian uint32""" |
|
self.flist.append(val.to_bytes(length=4, byteorder="big")) |
|
|
|
def put_u64(self, val: int) -> None: |
|
"""Big-endian uint64""" |
|
self.flist.append(val.to_bytes(length=8, byteorder="big")) |
|
|
|
def put_sshstr(self, val: bytes | _FragList) -> None: |
|
"""Bytes prefixed with u32 length""" |
|
if isinstance(val, (bytes, memoryview, bytearray)): |
|
self.put_u32(len(val)) |
|
self.flist.append(val) |
|
else: |
|
self.put_u32(val.size()) |
|
self.flist.extend(val.flist) |
|
|
|
def put_mpint(self, val: int) -> None: |
|
"""Big-endian bigint prefixed with u32 length""" |
|
self.put_sshstr(_to_mpint(val)) |
|
|
|
def size(self) -> int: |
|
"""Current number of bytes""" |
|
return sum(map(len, self.flist)) |
|
|
|
def render(self, dstbuf: memoryview, pos: int = 0) -> int: |
|
"""Write into bytearray""" |
|
for frag in self.flist: |
|
flen = len(frag) |
|
start, pos = pos, pos + flen |
|
dstbuf[start:pos] = frag |
|
return pos |
|
|
|
def tobytes(self) -> bytes: |
|
"""Return as bytes""" |
|
buf = memoryview(bytearray(self.size())) |
|
self.render(buf) |
|
return buf.tobytes() |
|
|
|
|
|
class _SSHFormatRSA:
    """Format for RSA keys.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[int, int], memoryview]:
        """Parse the public fields ``(e, n)`` and return the leftover data."""
        exponent, data = _get_mpint(data)
        modulus, data = _get_mpint(data)
        return (exponent, modulus), data

    def load_public(
        self, data: memoryview
    ) -> tuple[rsa.RSAPublicKey, memoryview]:
        """Build an RSA public key from its wire encoding."""
        (exponent, modulus), data = self.get_public(data)
        numbers = rsa.RSAPublicNumbers(exponent, modulus)
        return numbers.public_key(), data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[rsa.RSAPrivateKey, memoryview]:
        """Build an RSA private key from its wire encoding.

        ``pubfields`` is the ``(e, n)`` tuple parsed from the public section;
        the private section must repeat the same values.

        Raises:
            ValueError: if the public values in the two sections differ.
        """
        n, data = _get_mpint(data)
        e, data = _get_mpint(data)
        d, data = _get_mpint(data)
        iqmp, data = _get_mpint(data)
        p, data = _get_mpint(data)
        q, data = _get_mpint(data)

        if pubfields != (e, n):
            raise ValueError("Corrupt data: rsa field mismatch")

        # The CRT exponents are not stored; derive them from d, p and q.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return numbers.private_key(), data

    def encode_public(
        self, public_key: rsa.RSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields (e, then n) to ``f_pub``."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: rsa.RSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the private fields (n, e, d, iqmp, p, q) to ``f_priv``."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)
|
|
|
|
|
class _SSHFormatDSA:
    """Format for DSA keys.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
        """Parse the public fields ``(p, q, g, y)``; return leftover data."""
        p, data = _get_mpint(data)
        q, data = _get_mpint(data)
        g, data = _get_mpint(data)
        y, data = _get_mpint(data)
        return (p, q, g, y), data

    def load_public(
        self, data: memoryview
    ) -> tuple[dsa.DSAPublicKey, memoryview]:
        """Build a DSA public key from its wire encoding."""
        (p, q, g, y), data = self.get_public(data)
        numbers = dsa.DSAPublicNumbers(y, dsa.DSAParameterNumbers(p, q, g))
        self._validate(numbers)
        return numbers.public_key(), data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[dsa.DSAPrivateKey, memoryview]:
        """Build a DSA private key from its wire encoding.

        ``pubfields`` is the tuple parsed from the public section; the
        private section must repeat the same values.

        Raises:
            ValueError: if the public values differ between sections or the
                key size is unsupported.
        """
        fields, data = self.get_public(data)
        x, data = _get_mpint(data)

        if fields != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")

        p, q, g, y = fields
        pub_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(pub_numbers)
        priv_numbers = dsa.DSAPrivateNumbers(x, pub_numbers)
        return priv_numbers.private_key(), data

    def encode_public(
        self, public_key: dsa.DSAPublicKey, f_pub: _FragList
    ) -> None:
        """Append the public fields (p, q, g, y) to ``f_pub``."""
        numbers = public_key.public_numbers()
        params = numbers.parameter_numbers
        self._validate(numbers)

        for value in (params.p, params.q, params.g, numbers.y):
            f_pub.put_mpint(value)

    def encode_private(
        self, private_key: dsa.DSAPrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by x to ``f_priv``."""
        self.encode_public(private_key.public_key(), f_priv)
        secret = private_key.private_numbers().x
        f_priv.put_mpint(secret)

    def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None:
        """Reject keys whose modulus p is not exactly 1024 bits long."""
        if public_numbers.parameter_numbers.p.bit_length() != 1024:
            raise ValueError("SSH supports only 1024 bit DSA keys")
|
|
|
|
|
class _SSHFormatECDSA:
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve):
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview, memoryview], memoryview]:
        """Parse the public fields ``(curve, point)``; return leftover data.

        Raises:
            ValueError: if the curve name does not match this format or the
                point is empty.
            NotImplementedError: if the point is not in uncompressed form
                (first byte 4).
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point has no format byte to inspect; report it as
        # malformed input rather than letting point[0] raise IndexError.
        if len(point) == 0:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key from its wire encoding."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
        """Build an ECDSA private key from its wire encoding.

        ``pubfields`` is the ``(curve, point)`` tuple parsed from the public
        section; the private section must repeat the same values.

        Raises:
            ValueError: if the public values differ between sections.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve)
        return private_key, data

    def encode_public(
        self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList
    ) -> None:
        """Append the curve name and uncompressed point to ``f_pub``."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(
        self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public fields followed by the secret to ``f_priv``."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
|
|
|
|
|
class _SSHFormatEd25519:
    """Format for Ed25519 keys.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point
    """

    def get_public(
        self, data: memoryview
    ) -> tuple[tuple[memoryview], memoryview]:
        """Parse the public field ``(point,)``; return leftover data."""
        point, remainder = _get_sshstr(data)
        return (point,), remainder

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from its wire encoding."""
        (point,), remainder = self.get_public(data)
        raw_point = point.tobytes()
        return (
            ed25519.Ed25519PublicKey.from_public_bytes(raw_point),
            remainder,
        )

    def load_private(
        self, data: memoryview, pubfields
    ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
        """Build an Ed25519 private key from its wire encoding.

        The keypair string holds the 32-byte secret followed by a copy of
        the public point; that copy, the point preceding it and
        ``pubfields`` must all agree.

        Raises:
            ValueError: if the public point copies disagree.
        """
        (point,), data = self.get_public(data)
        keypair, data = _get_sshstr(data)

        secret, embedded_point = keypair[:32], keypair[32:]
        if point != embedded_point or (point,) != pubfields:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), data

    def encode_public(
        self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList
    ) -> None:
        """Append the raw public point to ``f_pub``."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(
        self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList
    ) -> None:
        """Append the public point and the secret+point pair to ``f_priv``."""
        public_key = private_key.public_key()
        secret_bytes = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point_bytes = public_key.public_bytes(
            Encoding.Raw, PublicFormat.Raw
        )
        keypair = _FragList([secret_bytes, point_bytes])

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)
|
|
|
|
|
def load_application(data) -> tuple[memoryview, memoryview]:
    """Read a U2F application string and check its ``ssh:`` prefix.

    Returns the application string and the data following it.

    Raises:
        ValueError: if the string is malformed or does not start with
            ``b"ssh:"``.
    """
    application, data = _get_sshstr(data)
    app_bytes = application.tobytes()
    if not app_bytes.startswith(b"ssh:"):
        # Format the bytes, not the memoryview: a memoryview renders as
        # "<memory at 0x...>", which hides the offending value.
        raise ValueError(
            "U2F application string does not start with b'ssh:' "
            f"({app_bytes!r})"
        )
    return application, data
|
|
|
|
|
class _SSHFormatSKEd25519:
    """
    The format of a sk-ssh-ed25519@openssh.com public key is:

        string		"sk-ssh-ed25519@openssh.com"
        string		public key
        string		application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
        """Build an Ed25519 public key from security-key encoded data.

        The key itself is parsed by the plain Ed25519 format; the trailing
        application string is checked and then dropped.
        """
        plain_format = _lookup_kformat(_SSH_ED25519)
        public_key, remainder = plain_format.load_public(data)
        _application, remainder = load_application(remainder)
        return public_key, remainder
|
|
|
|
|
class _SSHFormatSKECDSA:
    """
    The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:

        string		"sk-ecdsa-sha2-nistp256@openssh.com"
        string		curve name
        ec_point	Q
        string		application (user-specified, but typically "ssh:")
    """

    def load_public(
        self, data: memoryview
    ) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
        """Build an ECDSA public key from security-key encoded data.

        The key itself is parsed by the plain nistp256 ECDSA format; the
        trailing application string is checked and then dropped.
        """
        plain_format = _lookup_kformat(_ECDSA_NISTP256)
        public_key, remainder = plain_format.load_public(data)
        _application, remainder = load_application(remainder)
        return public_key, remainder
|
|
|
|
|
# Key type identifier -> format handler. The two security-key handlers only
# implement load_public; the others also parse and encode private keys.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
    _SK_SSH_ED25519: _SSHFormatSKEd25519(),
    _SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
|
|
|
|
|
def _lookup_kformat(key_type: bytes):
    """Return the format handler registered for ``key_type``.

    ``key_type`` may be any bytes-like object; non-``bytes`` input is
    converted before the lookup.

    Raises:
        UnsupportedAlgorithm: if no handler is registered for the type.
    """
    if isinstance(key_type, bytes):
        name = key_type
    else:
        name = memoryview(key_type).tobytes()
    handler = _KEY_FORMATS.get(name)
    if handler is None:
        raise UnsupportedAlgorithm(f"Unsupported key type: {name!r}")
    return handler
|
|
|
|
|
# Private key classes that can be loaded from / serialized to the OpenSSH
# private key format in this module.
SSHPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]
|
|
|
|
|
def load_ssh_private_key(
    data: bytes,
    password: bytes | None,
    backend: typing.Any = None,
) -> SSHPrivateKeyTypes:
    """Load private key from OpenSSH custom encoding.

    Args:
        data: Bytes-like object containing the armored private key.
        password: Password for an encrypted key, or None.
        backend: Unused by this function.

    Returns:
        The private key object built by the matching key format handler.

    Raises:
        ValueError: if the data is not in OpenSSH private key format, is
            corrupt, or is encrypted and no password was given.
        UnsupportedAlgorithm: if the key type, cipher or KDF is unsupported.
    """
    utils._check_byteslike("data", data)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername_bytes = ciphername.tobytes()
        if ciphername_bytes not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm(
                f"Unsupported cipher: {ciphername_bytes!r}"
            )
        if kdfname != _BCRYPT:
            # Report the KDF name itself; formatting the memoryview would
            # only show "<memory at 0x...>".
            raise UnsupportedAlgorithm(
                f"Unsupported KDF: {kdfname.tobytes()!r}"
            )
        blklen = _SSH_CIPHERS[ciphername_bytes].block_len
        tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len

        # load secret data
        edata, data = _get_sshstr(data)

        # For AEAD ciphers the authentication tag is whatever follows the
        # encrypted string; for other ciphers nothing may follow it.
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            tag = bytes(data)
            if len(tag) != tag_len:
                raise ValueError("Corrupt data: invalid tag length for cipher")
        else:
            _check_empty(data)
        _check_block_size(edata, blklen)
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds)
        dec = ciph.decryptor()
        edata = memoryview(dec.update(edata))
        if _SSH_CIPHERS[ciphername_bytes].is_aead:
            assert isinstance(dec, AEADDecryptionContext)
            _check_empty(dec.finalize_with_tag(tag))
        else:
            # edata was required to be whole blocks, so finalize is not
            # expected to produce any further output.
            _check_empty(dec.finalize())
    else:
        # load secret data
        edata, data = _get_sshstr(data)
        _check_empty(data)
        blklen = 8
        _check_block_size(edata, blklen)

    # The secret section starts with the same uint32 twice; a mismatch
    # means the data is corrupt or was decrypted incorrectly.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields)
    # The string after the key fields (the comment slot written by the
    # serializer) is read and discarded.
    _, edata = _get_sshstr(edata)

    # Whatever remains must be the padding sequence 1, 2, 3, ...
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA keys are deprecated and will be removed in a future "
            "release.",
            utils.DeprecatedIn40,
            stacklevel=2,
        )

    return private_key
|
|
|
|
|
def _serialize_ssh_private_key(
    private_key: SSHPrivateKeyTypes,
    password: bytes,
    encryption_algorithm: KeySerializationEncryption,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    Args:
        private_key: Key to serialize.
        password: Encryption password; an empty value produces an
            unencrypted key.
        encryption_algorithm: Consulted only for an optional KDF round
            count (``_kdf_rounds``).

    Returns:
        The armored, base64-encoded private key.
    """
    utils._check_bytes("password", password)
    if isinstance(private_key, dsa.DSAPrivateKey):
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )

    key_type = _get_ssh_key_type(private_key)
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername].block_len
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        # A caller-supplied round count overrides the default.
        if (
            isinstance(encryption_algorithm, _KeySerializationEncryption)
            and encryption_algorithm._kdf_rounds is not None
        ):
            rounds = encryption_algorithm._kdf_rounds
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        ciph = _init_cipher(ciphername, password, salt, rounds)
    else:
        # Unencrypted: kdf options stay empty and padding uses 8-byte blocks.
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random value written twice at the start of the secret section.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad to a multiple of blklen; a full block of padding is added when the
    # size is already aligned.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy data
    slen = f_secrets.size()
    mlen = f_main.size()
    # The buffer has blklen spare bytes past the rendered data.
    # NOTE(review): presumably slack required by update_into's output
    # buffer size rule — confirm against the cipher API.
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secret section is the last slen bytes of the rendered data.
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    # Only the first mlen bytes are encoded; the spare bytes are dropped.
    return _ssh_pem_encode(buf[:mlen])
|
|
|
|
|
# Public key classes that can be loaded from / serialized to the one-line
# OpenSSH public key format in this module.
SSHPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

# Public key classes exposed by certificates; same as above without DSA.
SSHCertPublicKeyTypes = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    ed25519.Ed25519PublicKey,
]
|
|
|
|
|
class SSHCertificateType(enum.Enum):
    """Certificate type; values match the numeric type field of a cert."""

    # SSHCertificate converts the parsed uint32 type field into one of these
    # members and rejects any other value.
    USER = 1
    HOST = 2
|
|
|
|
|
class SSHCertificate:
    """A parsed OpenSSH certificate.

    Holds the decoded certificate fields plus the raw certificate body and
    the to-be-signed portion needed for signature verification.
    """

    def __init__(
        self,
        _nonce: memoryview,
        _public_key: SSHPublicKeyTypes,
        _serial: int,
        _cctype: int,
        _key_id: memoryview,
        _valid_principals: list[bytes],
        _valid_after: int,
        _valid_before: int,
        _critical_options: dict[bytes, bytes],
        _extensions: dict[bytes, bytes],
        _sig_type: memoryview,
        _sig_key: memoryview,
        _inner_sig_type: memoryview,
        _signature: memoryview,
        _tbs_cert_body: memoryview,
        _cert_key_type: bytes,
        _cert_body: memoryview,
    ):
        """Store the parsed fields.

        Raises:
            ValueError: if ``_cctype`` is not a known certificate type.
        """
        self._nonce = _nonce
        self._public_key = _public_key
        self._serial = _serial
        try:
            cert_type = SSHCertificateType(_cctype)
        except ValueError:
            raise ValueError("Invalid certificate type")
        self._type = cert_type
        self._key_id = _key_id
        self._valid_principals = _valid_principals
        self._valid_after = _valid_after
        self._valid_before = _valid_before
        self._critical_options = _critical_options
        self._extensions = _extensions
        self._sig_type = _sig_type
        self._sig_key = _sig_key
        self._inner_sig_type = _inner_sig_type
        self._signature = _signature
        self._cert_key_type = _cert_key_type
        self._cert_body = _cert_body
        self._tbs_cert_body = _tbs_cert_body

    @property
    def nonce(self) -> bytes:
        """The certificate nonce as bytes."""
        return bytes(self._nonce)

    def public_key(self) -> SSHCertPublicKeyTypes:
        """Return the public key the certificate was issued for."""
        # The stored key is typed with the wider public-key union; the cast
        # only narrows the static type, nothing is checked at runtime.
        return typing.cast(SSHCertPublicKeyTypes, self._public_key)

    @property
    def serial(self) -> int:
        """The certificate serial number."""
        return self._serial

    @property
    def type(self) -> SSHCertificateType:
        """Whether this is a user or a host certificate."""
        return self._type

    @property
    def key_id(self) -> bytes:
        """The key identifier as bytes."""
        return bytes(self._key_id)

    @property
    def valid_principals(self) -> list[bytes]:
        """The list of principals stored in the certificate."""
        return self._valid_principals

    @property
    def valid_before(self) -> int:
        """The ``valid before`` field as an integer."""
        return self._valid_before

    @property
    def valid_after(self) -> int:
        """The ``valid after`` field as an integer."""
        return self._valid_after

    @property
    def critical_options(self) -> dict[bytes, bytes]:
        """Critical options, name to value."""
        return self._critical_options

    @property
    def extensions(self) -> dict[bytes, bytes]:
        """Extensions, name to value."""
        return self._extensions

    def signature_key(self) -> SSHCertPublicKeyTypes:
        """Load and return the public key of the signer.

        Raises:
            ValueError: if data is left over after the signing key.
        """
        handler = _lookup_kformat(self._sig_type)
        signer, leftover = handler.load_public(self._sig_key)
        _check_empty(leftover)
        return signer

    def public_bytes(self) -> bytes:
        """Return ``<cert key type> <base64 cert body>`` without newline."""
        encoded_body = binascii.b2a_base64(
            bytes(self._cert_body), newline=False
        )
        return b" ".join([bytes(self._cert_key_type), encoded_body])

    def verify_cert_signature(self) -> None:
        """Verify the certificate signature against the signing key.

        Returns None on success; the key's ``verify`` call raises on a
        signature mismatch.
        """
        signer = self.signature_key()
        signed_blob = bytes(self._tbs_cert_body)

        if isinstance(signer, ed25519.Ed25519PublicKey):
            signer.verify(bytes(self._signature), signed_blob)
            return

        if isinstance(signer, ec.EllipticCurvePublicKey):
            # The signature holds r and s as two mpints; re-encode them in
            # the form the verify call expects.
            r, leftover = _get_mpint(self._signature)
            s, leftover = _get_mpint(leftover)
            _check_empty(leftover)
            dss_sig = asym_utils.encode_dss_signature(r, s)
            curve_hash = _get_ec_hash_alg(signer.curve)
            signer.verify(dss_sig, signed_blob, ec.ECDSA(curve_hash))
            return

        assert isinstance(signer, rsa.RSAPublicKey)
        # The inner signature type selects the hash for RSA signatures.
        rsa_hash: hashes.HashAlgorithm
        if self._inner_sig_type == _SSH_RSA:
            rsa_hash = hashes.SHA1()
        elif self._inner_sig_type == _SSH_RSA_SHA256:
            rsa_hash = hashes.SHA256()
        else:
            assert self._inner_sig_type == _SSH_RSA_SHA512
            rsa_hash = hashes.SHA512()
        signer.verify(
            bytes(self._signature),
            signed_blob,
            padding.PKCS1v15(),
            rsa_hash,
        )
|
|
|
|
|
def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
    """Return the hash paired with ``curve`` for ECDSA signatures.

    SECP256R1 uses SHA-256, SECP384R1 uses SHA-384; any other curve is
    asserted to be SECP521R1 and uses SHA-512.
    """
    if isinstance(curve, ec.SECP256R1):
        return hashes.SHA256()
    if isinstance(curve, ec.SECP384R1):
        return hashes.SHA384()
    assert isinstance(curve, ec.SECP521R1)
    return hashes.SHA512()
|
|
|
|
|
def _load_ssh_public_identity(
    data: bytes,
    _legacy_dsa_allowed=False,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Parse a one-line OpenSSH public key or certificate.

    Args:
        data: Bytes-like line of the form ``<key type> <base64 body>``.
        _legacy_dsa_allowed: When false, DSA keys and DSA-signed
            certificates are rejected.

    Returns:
        An ``SSHCertificate`` when the key type carries the certificate
        suffix, otherwise the public key object.

    Raises:
        ValueError: if the line or the decoded body is malformed.
        UnsupportedAlgorithm: for unknown key types or disallowed DSA.
    """
    utils._check_byteslike("data", data)

    m = _SSH_PUBKEY_RC.match(data)
    if not m:
        raise ValueError("Invalid line format")
    key_type = orig_key_type = m.group(1)
    key_body = m.group(2)
    with_cert = False
    # A certificate is named "<plain key type><cert suffix>"; strip the
    # suffix to find the handler for the embedded public key.
    if key_type.endswith(_CERT_SUFFIX):
        with_cert = True
        key_type = key_type[: -len(_CERT_SUFFIX)]
    if key_type == _SSH_DSA and not _legacy_dsa_allowed:
        raise UnsupportedAlgorithm(
            "DSA keys aren't supported in SSH certificates"
        )
    kformat = _lookup_kformat(key_type)

    try:
        rest = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid format")

    if with_cert:
        # Keep a view of the whole body; the signed portion is sliced from
        # it further down.
        cert_body = rest
    # The type stored inside the body must equal the one on the line
    # (including the cert suffix, if any).
    inner_key_type, rest = _get_sshstr(rest)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(rest)
    if with_cert:
        serial, rest = _get_u64(rest)
        cctype, rest = _get_u32(rest)
        key_id, rest = _get_sshstr(rest)
        # Principals are a string containing a sequence of strings.
        principals, rest = _get_sshstr(rest)
        valid_principals = []
        while principals:
            principal, principals = _get_sshstr(principals)
            valid_principals.append(bytes(principal))
        valid_after, rest = _get_u64(rest)
        valid_before, rest = _get_u64(rest)
        crit_options, rest = _get_sshstr(rest)
        critical_options = _parse_exts_opts(crit_options)
        exts, rest = _get_sshstr(rest)
        extensions = _parse_exts_opts(exts)
        # One string is read and discarded here.
        # NOTE(review): presumably the certificate's "reserved" field —
        # confirm against the certificate format specification.
        _, rest = _get_sshstr(rest)
        # Signing key: a string holding its own type followed by key data.
        sig_key_raw, rest = _get_sshstr(rest)
        sig_type, sig_key = _get_sshstr(sig_key_raw)
        if sig_type == _SSH_DSA and not _legacy_dsa_allowed:
            raise UnsupportedAlgorithm(
                "DSA signatures aren't supported in SSH certificates"
            )
        # Everything before the signature string is the signed data: the
        # full body minus what is still unparsed at this point.
        tbs_cert_body = cert_body[: -len(rest)]
        signature_raw, rest = _get_sshstr(rest)
        _check_empty(rest)
        inner_sig_type, sig_rest = _get_sshstr(signature_raw)
        # An ssh-rsa signing key may sign with any of the three RSA
        # signature types; every other key type must match exactly.
        if (
            sig_type == _SSH_RSA
            and inner_sig_type
            not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA]
        ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type):
            raise ValueError("Signature key type does not match")
        signature, sig_rest = _get_sshstr(sig_rest)
        _check_empty(sig_rest)
        return SSHCertificate(
            nonce,
            public_key,
            serial,
            cctype,
            key_id,
            valid_principals,
            valid_after,
            valid_before,
            critical_options,
            extensions,
            sig_type,
            sig_key,
            inner_sig_type,
            signature,
            tbs_cert_body,
            orig_key_type,
            cert_body,
        )
    else:
        _check_empty(rest)
        return public_key
|
|
|
|
|
def load_ssh_public_identity(
    data: bytes,
) -> SSHCertificate | SSHPublicKeyTypes:
    """Load an OpenSSH public key or certificate from one line of data.

    DSA keys and DSA-signed certificates are rejected.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=False)
    return identity
|
|
|
|
|
def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
    """Parse a sequence of (name, value) string pairs into a dict.

    A non-empty value is itself a length-prefixed string and is unwrapped.

    Raises:
        ValueError: on a duplicate name, on a name that sorts before the
            previous one, or on extra data after an unwrapped value.
    """
    parsed: dict[bytes, bytes] = {}
    previous = None
    remaining = exts_opts
    while remaining:
        raw_name, remaining = _get_sshstr(remaining)
        key: bytes = bytes(raw_name)
        if key in parsed:
            raise ValueError("Duplicate name")
        if previous is not None and key < previous:
            raise ValueError("Fields not lexically sorted")
        payload, remaining = _get_sshstr(remaining)
        if len(payload) > 0:
            payload, trailing = _get_sshstr(payload)
            if len(trailing) > 0:
                raise ValueError("Unexpected extra data after value")
        parsed[key] = bytes(payload)
        previous = key
    return parsed
|
|
|
|
|
def load_ssh_public_key(
    data: bytes, backend: typing.Any = None
) -> SSHPublicKeyTypes:
    """Load the public key from an OpenSSH public key or certificate.

    The identity is parsed with legacy DSA handling enabled. When the
    input is a certificate, the key it certifies is returned rather than
    the certificate itself.

    :param data: the serialized public key or certificate.
    :param backend: accepted but not used by this function.
    :returns: the parsed public key.

    Emits a ``DeprecatedIn40`` warning when the resulting key is DSA.
    """
    identity = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
    key: SSHPublicKeyTypes = (
        identity.public_key()
        if isinstance(identity, SSHCertificate)
        else identity
    )
    if not isinstance(key, dsa.DSAPublicKey):
        return key

    warnings.warn(
        "SSH DSA keys are deprecated and will be removed in a future "
        "release.",
        utils.DeprecatedIn40,
        stacklevel=2,
    )
    return key
|
|
|
|
|
def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes:
    """One-line public key format for OpenSSH

    The result is ``<key type> <base64 blob>``, where the blob is the key
    type string followed by the key material written by the encoder that
    ``_lookup_kformat`` returns for that key type.

    :param public_key: the key to serialize.
    :returns: the single-line encoding, without a trailing newline.

    Emits a ``DeprecatedIn40`` warning when ``public_key`` is a DSA key.
    """
    if isinstance(public_key, dsa.DSAPublicKey):
        # NOTE(review): stacklevel=4 presumably skips intermediate frames
        # between the user's call and this function -- confirm.
        warnings.warn(
            "SSH DSA key support is deprecated and will be "
            "removed in a future release",
            utils.DeprecatedIn40,
            stacklevel=4,
        )
    algorithm = _get_ssh_key_type(public_key)
    encoder = _lookup_kformat(algorithm)

    blob = _FragList()
    blob.put_sshstr(algorithm)
    encoder.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return b" ".join((algorithm, encoded))
|
|
|
|
|
# Private key types accepted as the signing key by
# SSHCertificateBuilder.sign().
SSHCertPrivateKeyTypes = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]


# Largest number of principals SSHCertificateBuilder.valid_principals()
# accepts; a longer list is rejected with ValueError.
# NOTE(review): presumably mirrors a limit used by OpenSSH -- confirm.
_SSHKEY_CERT_MAX_PRINCIPALS = 256
|
|
|
|
|
class SSHCertificateBuilder:
    """Fluent builder that assembles and signs an OpenSSH certificate.

    Each setter validates its argument and returns a *new* builder with
    that one field changed; the builder it was called on is left as it
    was. Most fields may be set only once. ``sign()`` serializes the
    collected fields, signs them, and returns the parsed certificate.
    """

    def __init__(
        self,
        _public_key: SSHCertPublicKeyTypes | None = None,
        _serial: int | None = None,
        _type: SSHCertificateType | None = None,
        _key_id: bytes | None = None,
        _valid_principals: list[bytes] | None = None,
        _valid_for_all_principals: bool = False,
        _valid_before: int | None = None,
        _valid_after: int | None = None,
        _critical_options: list[tuple[bytes, bytes]] | None = None,
        _extensions: list[tuple[bytes, bytes]] | None = None,
    ):
        """Create a builder; the underscore-prefixed arguments are internal.

        List arguments default to ``None`` (meaning "empty") and are
        copied, so no list object is shared between builders or with the
        caller.
        """
        self._public_key = _public_key
        self._serial = _serial
        self._type = _type
        self._key_id = _key_id
        self._valid_principals: list[bytes] = (
            [] if _valid_principals is None else list(_valid_principals)
        )
        self._valid_for_all_principals = _valid_for_all_principals
        self._valid_before = _valid_before
        self._valid_after = _valid_after
        self._critical_options: list[tuple[bytes, bytes]] = (
            [] if _critical_options is None else list(_critical_options)
        )
        self._extensions: list[tuple[bytes, bytes]] = (
            [] if _extensions is None else list(_extensions)
        )

    def _replace(self, **changes: typing.Any) -> SSHCertificateBuilder:
        """Return a copy of this builder with the given fields overridden."""
        fields: dict[str, typing.Any] = {
            "_public_key": self._public_key,
            "_serial": self._serial,
            "_type": self._type,
            "_key_id": self._key_id,
            "_valid_principals": self._valid_principals,
            "_valid_for_all_principals": self._valid_for_all_principals,
            "_valid_before": self._valid_before,
            "_valid_after": self._valid_after,
            "_critical_options": self._critical_options,
            "_extensions": self._extensions,
        }
        fields.update(changes)
        return SSHCertificateBuilder(**fields)

    @staticmethod
    def _coerce_timestamp(field: str, value: int | float) -> int:
        """Validate a validity bound and truncate it to an integer.

        :raises TypeError: if ``value`` is neither ``int`` nor ``float``.
        :raises ValueError: if the truncated value is outside
            ``[0, 2**64)``, including infinities and NaN.
        """
        if not isinstance(value, (int, float)):
            raise TypeError(f"{field} must be an int or float")
        try:
            seconds = int(value)
        except (OverflowError, ValueError):
            # int() raises OverflowError for +/-inf and ValueError for
            # NaN; report both as an out-of-range value.
            raise ValueError(f"{field} must [0, 2**64)") from None
        if seconds < 0 or seconds >= 2**64:
            raise ValueError(f"{field} must [0, 2**64)")
        return seconds

    @staticmethod
    def _encode_options(pairs: list[tuple[bytes, bytes]]) -> bytes:
        """Encode name/value pairs, ordered by name, as packed SSH strings.

        A non-empty value is wrapped in its own length-prefixed string
        before being written; an empty value is written as an empty
        string.
        """
        encoded = _FragList()
        # sorted() leaves the builder's own list untouched.
        for name, value in sorted(pairs, key=lambda pair: pair[0]):
            encoded.put_sshstr(name)
            if len(value) > 0:
                wrapped = _FragList()
                wrapped.put_sshstr(value)
                encoded.put_sshstr(wrapped.tobytes())
            else:
                encoded.put_sshstr(value)
        return encoded.tobytes()

    def public_key(
        self, public_key: SSHCertPublicKeyTypes
    ) -> SSHCertificateBuilder:
        """Set the public key the certificate will certify.

        :raises TypeError: if the key is not an EC, RSA or Ed25519 public
            key.
        :raises ValueError: if a public key was already set.
        """
        if not isinstance(
            public_key,
            (
                ec.EllipticCurvePublicKey,
                rsa.RSAPublicKey,
                ed25519.Ed25519PublicKey,
            ),
        ):
            raise TypeError("Unsupported key type")
        if self._public_key is not None:
            raise ValueError("public_key already set")

        return self._replace(_public_key=public_key)

    def serial(self, serial: int) -> SSHCertificateBuilder:
        """Set the serial number (an integer in ``[0, 2**64)``).

        :raises TypeError: if ``serial`` is not an integer.
        :raises ValueError: if it is out of range or already set.
        """
        if not isinstance(serial, int):
            raise TypeError("serial must be an integer")
        if not 0 <= serial < 2**64:
            raise ValueError("serial must be between 0 and 2**64")
        if self._serial is not None:
            raise ValueError("serial already set")

        return self._replace(_serial=serial)

    def type(self, type: SSHCertificateType) -> SSHCertificateBuilder:
        """Set the certificate type.

        :raises TypeError: if ``type`` is not an ``SSHCertificateType``.
        :raises ValueError: if the type was already set.
        """
        if not isinstance(type, SSHCertificateType):
            raise TypeError("type must be an SSHCertificateType")
        if self._type is not None:
            raise ValueError("type already set")

        return self._replace(_type=type)

    def key_id(self, key_id: bytes) -> SSHCertificateBuilder:
        """Set the key identifier.

        :raises TypeError: if ``key_id`` is not ``bytes``.
        :raises ValueError: if the key id was already set.
        """
        if not isinstance(key_id, bytes):
            raise TypeError("key_id must be bytes")
        if self._key_id is not None:
            raise ValueError("key_id already set")

        return self._replace(_key_id=key_id)

    def valid_principals(
        self, valid_principals: list[bytes]
    ) -> SSHCertificateBuilder:
        """Set the principals the certificate is valid for.

        :raises ValueError: if the builder is already valid for all
            principals, if principals were already set, or if more than
            ``_SSHKEY_CERT_MAX_PRINCIPALS`` are given.
        :raises TypeError: if the list is empty or holds a non-``bytes``
            item.
        """
        if self._valid_for_all_principals:
            raise ValueError(
                "Principals can't be set because the cert is valid "
                "for all principals"
            )
        if (
            not all(isinstance(x, bytes) for x in valid_principals)
            or not valid_principals
        ):
            raise TypeError(
                "principals must be a list of bytes and can't be empty"
            )
        if self._valid_principals:
            raise ValueError("valid_principals already set")

        if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS:
            raise ValueError(
                "Reached or exceeded the maximum number of valid_principals"
            )

        return self._replace(_valid_principals=valid_principals)

    def valid_for_all_principals(self) -> SSHCertificateBuilder:
        """Mark the certificate as valid for every principal.

        :raises ValueError: if explicit principals were set, or if this
            flag was already set.
        """
        if self._valid_principals:
            raise ValueError(
                "valid_principals already set, can't set "
                "valid_for_all_principals"
            )
        if self._valid_for_all_principals:
            raise ValueError("valid_for_all_principals already set")

        return self._replace(_valid_for_all_principals=True)

    def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
        """Set the end of the validity window; floats are truncated.

        :raises TypeError: if the value is neither ``int`` nor ``float``.
        :raises ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        valid_before = self._coerce_timestamp("valid_before", valid_before)
        if self._valid_before is not None:
            raise ValueError("valid_before already set")

        return self._replace(_valid_before=valid_before)

    def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
        """Set the start of the validity window; floats are truncated.

        :raises TypeError: if the value is neither ``int`` nor ``float``.
        :raises ValueError: if it is outside ``[0, 2**64)`` or already set.
        """
        valid_after = self._coerce_timestamp("valid_after", valid_after)
        if self._valid_after is not None:
            raise ValueError("valid_after already set")

        return self._replace(_valid_after=valid_after)

    def add_critical_option(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Add a critical option; names must be unique.

        :raises TypeError: if ``name`` or ``value`` is not ``bytes``.
        :raises ValueError: if ``name`` was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")

        if any(name == existing for existing, _ in self._critical_options):
            raise ValueError("Duplicate critical option name")

        return self._replace(
            _critical_options=[*self._critical_options, (name, value)]
        )

    def add_extension(
        self, name: bytes, value: bytes
    ) -> SSHCertificateBuilder:
        """Add an extension; names must be unique.

        :raises TypeError: if ``name`` or ``value`` is not ``bytes``.
        :raises ValueError: if ``name`` was already added.
        """
        if not isinstance(name, bytes) or not isinstance(value, bytes):
            raise TypeError("name and value must be bytes")

        if any(name == existing for existing, _ in self._extensions):
            raise ValueError("Duplicate extension name")

        return self._replace(_extensions=[*self._extensions, (name, value)])

    def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:
        """Serialize the certificate, sign it, and return it parsed.

        ``serial`` defaults to ``0`` and ``key_id`` to ``b""`` when unset.
        Critical options and extensions are written ordered by name. The
        builder itself is not modified.

        :param private_key: EC, RSA or Ed25519 signing key.
        :returns: the certificate, re-parsed from its serialized form.
        :raises TypeError: if ``private_key`` is of an unsupported type.
        :raises ValueError: if the public key, type, principals (or the
            all-principals flag), ``valid_before`` or ``valid_after`` are
            missing, or if ``valid_after`` is later than ``valid_before``.
        """
        if not isinstance(
            private_key,
            (
                ec.EllipticCurvePrivateKey,
                rsa.RSAPrivateKey,
                ed25519.Ed25519PrivateKey,
            ),
        ):
            raise TypeError("Unsupported private key type")

        if self._public_key is None:
            raise ValueError("public_key must be set")

        # Optional field: falls back to 0.
        serial = 0 if self._serial is None else self._serial

        if self._type is None:
            raise ValueError("type must be set")

        # Optional field: falls back to an empty identifier.
        key_id = b"" if self._key_id is None else self._key_id

        if not self._valid_principals and not self._valid_for_all_principals:
            raise ValueError(
                "valid_principals must be set if valid_for_all_principals "
                "is False"
            )

        if self._valid_before is None:
            raise ValueError("valid_before must be set")

        if self._valid_after is None:
            raise ValueError("valid_after must be set")

        if self._valid_after > self._valid_before:
            raise ValueError("valid_after must be earlier than valid_before")

        key_type = _get_ssh_key_type(self._public_key)
        cert_prefix = key_type + _CERT_SUFFIX

        # To-be-signed body, starting with a fresh 32-byte random nonce.
        nonce = os.urandom(32)
        kformat = _lookup_kformat(key_type)
        f = _FragList()
        f.put_sshstr(cert_prefix)
        f.put_sshstr(nonce)
        kformat.encode_public(self._public_key, f)
        f.put_u64(serial)
        f.put_u32(self._type.value)
        f.put_sshstr(key_id)
        fprincipals = _FragList()
        for p in self._valid_principals:
            fprincipals.put_sshstr(p)
        f.put_sshstr(fprincipals.tobytes())
        f.put_u64(self._valid_after)
        f.put_u64(self._valid_before)
        f.put_sshstr(self._encode_options(self._critical_options))
        f.put_sshstr(self._encode_options(self._extensions))
        # Empty string field between the extensions and the signing key.
        # NOTE(review): presumably the format's reserved field -- confirm.
        f.put_sshstr(b"")

        # Public half of the signing key, as its own nested blob.
        ca_type = _get_ssh_key_type(private_key)
        caformat = _lookup_kformat(ca_type)
        caf = _FragList()
        caf.put_sshstr(ca_type)
        caformat.encode_public(private_key.public_key(), caf)
        f.put_sshstr(caf.tobytes())

        # The signature covers everything written so far and is appended
        # as a nested (algorithm name, signature data) blob.
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            signature = private_key.sign(f.tobytes())
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            hash_alg = _get_ec_hash_alg(private_key.curve)
            signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg))
            # Re-encode the signature as two mpints (r, s).
            r, s = asym_utils.decode_dss_signature(signature)
            fsig = _FragList()
            fsig.put_sshstr(ca_type)
            fsigblob = _FragList()
            fsigblob.put_mpint(r)
            fsigblob.put_mpint(s)
            fsig.put_sshstr(fsigblob.tobytes())
            f.put_sshstr(fsig.tobytes())
        else:
            assert isinstance(private_key, rsa.RSAPrivateKey)
            # RSA signatures are always made with SHA-512 / PKCS#1 v1.5
            # and labelled rsa-sha2-512.
            fsig = _FragList()
            fsig.put_sshstr(_SSH_RSA_SHA512)
            signature = private_key.sign(
                f.tobytes(), padding.PKCS1v15(), hashes.SHA512()
            )
            fsig.put_sshstr(signature)
            f.put_sshstr(fsig.tobytes())

        cert_data = binascii.b2a_base64(f.tobytes()).strip()

        # Round-trip through the loader so the caller gets a fully parsed
        # certificate object; the cast narrows the loader's union return
        # type for type checkers.
        return typing.cast(
            SSHCertificate,
            load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])),
        )
|
|