#!/usr/bin/python3 --
import sys
from binascii import a2b_base64, b2a_base64


class BadCommandError(ValueError):
    """Raised when a command or its signature fails validation."""


# Fixed beginning of every accepted input: the cleartext-signature armor
# header line followed by the name of the "Hash" armor header.  main() relies
# on the hash name starting right after this prefix.
prefix = b"-----BEGIN PGP SIGNED MESSAGE-----\nHash: "
# Opening armor line of the signature plus the blank line that follows it
# (i.e. no armor headers are permitted between them).
sig_start = b"-----BEGIN PGP SIGNATURE-----\n\n"
# Closing armor line of the signature, including the newline that precedes it
# and the one that terminates it.
sig_end = b"\n-----END PGP SIGNATURE-----\n"
# Minimal-size armored signature.  In this file it is only used for its
# length: canonicalize_sig() rejects anything shorter before applying fixed
# offsets.  NOTE: the triple-quoted literal starts with a newline, so this
# template has one more line break after sig_start than a canonical signature.
shortest_sig = (
    sig_start
    + b"""
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
=AAAA"""
    + sig_end
)
# Minimal-size complete input (prefix, hash name, blank line, a tiny command
# line, then shortest_sig).  main() uses its length as a lower bound on the
# data read from stdin.
shortest_input = (
    prefix
    + b"""SHA256

Build-template\x20a\n"""
    + shortest_sig
)


def check_base64(untrusted_bytes: bytes) -> None:
    """
    Ensure every byte belongs to the standard base64 alphabet.

    Padding ("=") and whitespace are NOT part of the accepted set; callers
    must strip them beforehand.  Raises BadCommandError on the first
    offending byte; returns None otherwise (including for empty input).
    """
    alphabet = (
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        b"abcdefghijklmnopqrstuvwxyz"
        b"0123456789"
        b"+/"
    )
    if any(octet not in alphabet for octet in untrusted_bytes):
        raise BadCommandError("invalid base64 character")


def validate_upload_component_command(untrusted_command: bytes) -> None:
    """
    Validate an upload command.

    Expected form (single-space separated):

        Upload-component RELEASE COMPONENT COMMIT_SHA REPO (all | DIST_PAIR...)

    where COMMIT_SHA is 40 lowercase hex digits, REPO is "current" or
    "security-testing", and each DIST_PAIR is "host-<dist>" or "vm-<dist>"
    with <dist> starting with a lowercase ASCII letter.  The release name is
    not checked here.

    Raises BadCommandError when the command is malformed, and AssertionError
    when called with a command other than "Upload-component".
    """
    untrusted_args = untrusted_command.split(b" ")
    if len(untrusted_args) < 6:
        # The argument list is given in the order the arguments are parsed.
        raise BadCommandError(
            "Upload takes at least five arguments: release name, component name, "
            'commit sha, repo name, "all" or distributions'
        )
    if untrusted_args[0] != b"Upload-component":
        raise AssertionError("Wrong command passed to validate_upload_component_command")
    # untrusted_args[1] is the release name, which gets no further validation.
    (
        untrusted_component,
        untrusted_commit_sha,
        untrusted_repo_name,
    ) = untrusted_args[2:5]
    if not untrusted_component:
        raise BadCommandError("Empty component name")
    if b"." in untrusted_component:
        raise BadCommandError('"." not allowed in component name')
    allowed_repos = (b'current', b'security-testing')
    if len(untrusted_commit_sha) != 40:
        raise BadCommandError(
            "Wrong length of commit SHA: found {}, expected 40".format(
                len(untrusted_commit_sha)
            )
        )
    # Only lowercase hexadecimal digits (0-9, a-f) are accepted.
    for i in untrusted_commit_sha:
        if (0x30 <= i <= 0x39) or (0x61 <= i <= 0x66):
            pass
        else:
            raise BadCommandError(
                "'{}' is not a valid commit SHA".format(
                    untrusted_commit_sha.decode('ascii', 'strict')
                )
            )
    if untrusted_repo_name not in allowed_repos:
        raise BadCommandError("Unsupported repository name {}".format(
            untrusted_repo_name.decode('ascii', 'strict')
        ))

    # "all" as the first distribution argument short-circuits the per-dist
    # checks below.
    if untrusted_args[5] == b"all":
        return
    for untrusted_dist_pair in untrusted_args[5:]:
        # Other metacharacters are expected to have been rejected earlier
        # (check_command restricts the character set before calling this).
        if b"." in untrusted_dist_pair:
            raise BadCommandError('Dist pair cannot contain "."')
        if untrusted_dist_pair.startswith(b"host-"):
            untrusted_dist = untrusted_dist_pair[5:]
        elif untrusted_dist_pair.startswith(b"vm-"):
            untrusted_dist = untrusted_dist_pair[3:]
        else:
            raise BadCommandError(
                'Invalid package set (must be "host-" or "vm-", without quotes)'
            )
        if not untrusted_dist:
            # Here the pair is exactly b"host-" or b"vm-", so decoding cannot
            # fail; decode instead of str() to avoid a "b'...'" repr in the
            # message (and a BytesWarning under "python -bb").
            raise BadCommandError(
                "Empty dist following {}".format(
                    untrusted_dist_pair.decode('ascii', 'strict')
                )
            )
        if not (0x61 <= untrusted_dist[0] <= 0x7A):
            raise BadCommandError(
                "Package set must be followed by a lowercase ASCII letter"
            )


def validate_build_template_command(untrusted_command: bytes) -> None:
    """
    Validate a build template command.

    Expected form (single-space separated):

        Build-template RELEASE TEMPLATE TIMESTAMP

    TEMPLATE must be non-empty and free of ".", TIMESTAMP must be exactly
    12 ASCII decimal digits.  Raises BadCommandError on malformed input and
    AssertionError if the command word is not "Build-template".
    """
    untrusted_fields = untrusted_command.split(b" ")
    if len(untrusted_fields) != 4:
        raise BadCommandError(
            "Build-template takes 3 arguments: release name, dist, and timestamp"
        )
    command = untrusted_fields[0]
    untrusted_release_name = untrusted_fields[1]
    untrusted_template_name = untrusted_fields[2]
    untrusted_timestamp = untrusted_fields[3]

    if command != b"Build-template":
        raise AssertionError("Wrong command passed to validate_build_template_command")
    if not untrusted_release_name:
        raise BadCommandError("Empty release name")
    if not untrusted_template_name:
        raise BadCommandError("Empty template name")
    if b"." in untrusted_template_name:
        raise BadCommandError("'.' not allowed in template name")

    timestamp_len = len(untrusted_timestamp)
    if timestamp_len != 12:
        raise BadCommandError(f"Timestamp must be 12 bytes long, not {timestamp_len}")
    # Every byte must be an ASCII digit '0'..'9'.
    if not all(0x30 <= digit <= 0x39 for digit in untrusted_timestamp):
        raise BadCommandError(
            "Timestamp must be exactly 12 decimal digits, including leading zeros"
        )


def validate_upload_template_command(untrusted_command: bytes) -> None:
    """
    Validate an upload template command.

    Expected form (single-space separated):

        Upload-template RELEASE TEMPLATE TEMPLATE_SHA REPO

    TEMPLATE must be non-empty and free of ".", REPO must be "templates-itl"
    or "templates-community", and the part of TEMPLATE_SHA after its last
    "-" must be exactly 12 ASCII decimal digits.  The release name and the
    leading part of TEMPLATE_SHA are not checked here.

    Raises BadCommandError on malformed input and AssertionError if the
    command word is not "Upload-template".
    """
    untrusted_args = untrusted_command.split(b" ")
    if len(untrusted_args) != 5:
        raise BadCommandError(
            "Upload takes four arguments: release name, template name, "
            "template sha (format RELEASE-TIMESTAMP) and repo."
        )
    if untrusted_args[0] != b"Upload-template":
        raise AssertionError("Wrong command passed to validate_upload_template_command")
    # untrusted_args[1] is the release name, which gets no further validation.
    (
        untrusted_template_name,
        untrusted_template_sha,
        untrusted_repo_name,
    ) = untrusted_args[2:5]
    # These messages name the template (not a component), matching the
    # wording used by validate_build_template_command.
    if not untrusted_template_name:
        raise BadCommandError("Empty template name")
    if b"." in untrusted_template_name:
        raise BadCommandError("'.' not allowed in template name")

    allowed_repos = (b'templates-itl', b'templates-community')
    if untrusted_repo_name not in allowed_repos:
        raise BadCommandError("Unsupported repository name {}".format(
            untrusted_repo_name.decode('ascii', 'strict')
        ))

    # Only the component after the final "-" is treated as the timestamp.
    untrusted_timestamp = untrusted_template_sha.split(b"-")[-1]
    timestamp_len = len(untrusted_timestamp)
    if timestamp_len != 12:
        raise BadCommandError(f"Timestamp must be 12 bytes long, not {timestamp_len}")
    for i in untrusted_timestamp:
        if not (0x30 <= i <= 0x39):
            raise BadCommandError(
                "Timestamp must be exactly 12 decimal digits, including leading zeros"
            )


def validate_build_iso_command(untrusted_command: bytes) -> None:
    """
    Validate a build iso command.

    Expected form (single-space separated):

        Build-iso RELEASE ISO_VERSION ISO_TIMESTAMP

    RELEASE and ISO_VERSION must be non-empty; ISO_TIMESTAMP must be exactly
    12 ASCII decimal digits.  Raises BadCommandError on malformed input and
    AssertionError if the command word is not "Build-iso".
    """
    untrusted_fields = untrusted_command.split(b" ")
    if len(untrusted_fields) != 4:
        raise BadCommandError(
            "Build-iso takes 3 arguments: release name, iso version, and iso timestamp"
        )
    command = untrusted_fields[0]
    untrusted_release_name = untrusted_fields[1]
    untrusted_iso_version = untrusted_fields[2]
    untrusted_iso_timestamp = untrusted_fields[3]

    if command != b"Build-iso":
        raise AssertionError("Wrong command passed to validate_build_iso_command")
    if not untrusted_release_name:
        raise BadCommandError("Empty release name")
    if not untrusted_iso_version:
        raise BadCommandError("Empty iso version")

    timestamp_len = len(untrusted_iso_timestamp)
    if timestamp_len != 12:
        raise BadCommandError(f"Timestamp must be 12 bytes long, not {timestamp_len}")
    # Every byte must be an ASCII digit '0'..'9'.
    if not all(0x30 <= digit <= 0x39 for digit in untrusted_iso_timestamp):
        raise BadCommandError(
            "Timestamp must be exactly 12 decimal digits, including leading zeros"
        )


def check_command(untrusted_command: bytes) -> None:
    """
    Check a command provided via a GitHub comment.

    The command must be at most 255 bytes, consist only of ASCII letters,
    digits, "_", ".", "-" and single spaces, must not start with a space or
    dash, must not contain a space or dash directly after a space, and must
    not end with a space.  It is then handed to the validator matching its
    leading command word.  Raises BadCommandError on any violation or when
    the command word is unknown.
    """
    if len(untrusted_command) > 255:
        raise BadCommandError("command too long")

    ordinary = frozenset(
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        b"abcdefghijklmnopqrstuvwxyz"
        b"0123456789"
        b"_."
    )
    dash, space = 0x2D, 0x20
    # Pretend a space precedes the command so that a leading space or dash
    # is rejected by the same rule as a doubled separator.
    previous = space
    for octet in untrusted_command:
        if octet == dash or octet == space:
            if previous == space:
                raise BadCommandError("Double spaces or space-dash in command")
        elif octet not in ordinary:
            raise BadCommandError("invalid character in command")
        previous = octet

    if untrusted_command.endswith(b" "):
        raise BadCommandError("Trailing whitespace in command forbidden")

    # First matching command word wins; the trailing space in each keyword
    # means a bare command word without arguments is treated as unknown.
    validators = (
        (b"Build-template ", validate_build_template_command),
        (b"Upload-template ", validate_upload_template_command),
        (b"Upload-component ", validate_upload_component_command),
        (b"Build-iso ", validate_build_iso_command),
    )
    for keyword, validator in validators:
        if untrusted_command.startswith(keyword):
            validator(untrusted_command)
            return
    raise BadCommandError("Unknown command")


def check_one_signature_packet(untrusted_binary_sig: bytes) -> bytes:
    """
    Check that a byte string is exactly one OpenPGP signature packet.
    Return the contents of the packet.

    Both old-format and new-format packet headers are understood.  For
    new-format headers the one-octet, two-octet and five-octet body length
    encodings are accepted; partial body lengths are rejected, as are
    old-format indeterminate lengths.  The declared body length must cover
    the remainder of the input exactly and the packet tag must be 2
    (signature).

    Raises BadCommandError if any of these conditions does not hold,
    including when the input is too short to contain a packet header.

    This check purely covers the OpenPGP message framing, and should
    (almost) never need to be changed.
    """
    if not untrusted_binary_sig:
        raise BadCommandError("signature packet too short")
    untrusted_first_byte = untrusted_binary_sig[0]
    if not (untrusted_first_byte & 0x80):
        raise BadCommandError("first bit zero")
    packet_length = 0
    if untrusted_first_byte & 0x40:
        # new-format packet: tag in the low six bits, length octets follow
        tag = untrusted_first_byte & 0x3F
        if len(untrusted_binary_sig) < 2:
            raise BadCommandError("signature packet too short")
        untrusted_first_length_byte = untrusted_binary_sig[1]
        if untrusted_first_length_byte < 192:
            # one-octet length
            packet_length = untrusted_first_length_byte
            len_bytes = 1
        elif untrusted_first_length_byte < 224:
            # two-octet length:
            #   ((1st_octet - 192) << 8) + 2nd_octet + 192
            # The second length octet is at index 2 (index 1 is the first).
            if len(untrusted_binary_sig) < 3:
                raise BadCommandError("signature packet too short")
            packet_length = (
                ((untrusted_first_length_byte - 192) << 8)
                + untrusted_binary_sig[2]
                + 192
            )
            len_bytes = 2
        elif untrusted_first_length_byte == 255:
            # five-octet length: 0xFF followed by a 4-byte big-endian value.
            # A truncated slice yields a value that cannot match the length
            # check below, so no explicit bounds check is needed here.
            packet_length = int.from_bytes(untrusted_binary_sig[2:6], "big")
            len_bytes = 5
        else:
            raise BadCommandError("unsupported partial-length packet")
    else:
        # old-format packet: low two bits select the size of the length
        # field (1, 2 or 4 bytes); the value 3 means indeterminate length.
        len_bytes = 1 << (untrusted_first_byte & 0x3)
        tag = (untrusted_first_byte >> 2) & 0xF
        if len_bytes > 4:
            raise BadCommandError("forbidden indefinite-length packet")
        packet_length = int.from_bytes(
            untrusted_binary_sig[1 : len_bytes + 1], "big"
        )
    # The header is one tag byte plus len_bytes length bytes; the body must
    # account for everything else (no trailing data, no second packet).
    if packet_length != len(untrusted_binary_sig) - len_bytes - 1:
        raise BadCommandError("bad signature length")
    if tag != 2:
        raise BadCommandError("packet is not a signature")
    return untrusted_binary_sig[len_bytes + 1:]


def check_binary_signature(untrusted_binary_sig: bytes, hash_str: bytes) -> None:
    """
    Check the fixed leading fields of a binary OpenPGP signature.

    The input must be exactly one signature packet (see
    check_one_signature_packet).  Its body must start with version 4,
    signature type 1, and a hash algorithm ID of 8, 9 or 10 whose name
    (SHA256, SHA384, SHA512 respectively) equals hash_str.  The public-key
    algorithm byte is read but not checked.

    Raises BadCommandError on any mismatch.
    """
    untrusted_binary_sig = check_one_signature_packet(untrusted_binary_sig)
    # Guard the unpacking below: a shorter body would otherwise surface as a
    # bare ValueError rather than BadCommandError.
    if len(untrusted_binary_sig) < 4:
        raise BadCommandError("signature packet too short")
    version, ty, _pubkey_alg, hash_alg = untrusted_binary_sig[:4]
    if version != 4:
        raise BadCommandError("only version 4 signatures allowed")
    if ty != 1:
        raise BadCommandError("expected signature of type 1, got something else")
    if hash_alg < 8 or hash_alg > 10:
        raise BadCommandError("unsupported hash algorithm")
    # IDs 8..10 index directly into the tuple of permitted hash names.
    if hash_str != (b"SHA256", b"SHA384", b"SHA512")[hash_alg - 8]:
        raise BadCommandError("hash algorithm mismatch")


def reconstruct_armored_sig(
    untrusted_binary_crc24: bytes, untrusted_binary_sig: bytes
) -> bytes:
    """
    Reconstruct an armored signature from the components.

    The output is sig_start, the base64 of the signature wrapped at 64
    characters per line, a "=" line carrying the base64 of the CRC bytes,
    and sig_end.
    """
    untrusted_body = b2a_base64(untrusted_binary_sig, newline=False)
    # sig_start ends in two newlines; drop one because the join below
    # re-inserts a newline between every pair of pieces.
    untrusted_pieces = [sig_start[:-1]]
    for offset in range(0, len(untrusted_body), 64):
        untrusted_pieces.append(untrusted_body[offset : offset + 64])
    untrusted_crc_line = b"=" + b2a_base64(untrusted_binary_crc24, newline=False)
    untrusted_pieces.append(untrusted_crc_line + sig_end)
    return b"\n".join(untrusted_pieces)


def canonicalize_sig(untrusted_armored_sig: bytes, hash_str: bytes) -> bytes:
    """
    Canonicalize an ASCII-armored clear text signature.

    hash_str is the value of the Hash: armor line.

    The armor framing, CRC line and base64 body are checked, the decoded
    packet is checked with check_binary_signature, and the signature is then
    re-armored with reconstruct_armored_sig.  Raises BadCommandError when
    any check fails.
    """
    # Reject short input up front so the fixed offsets used below are valid.
    if len(untrusted_armored_sig) < len(shortest_sig):
        raise BadCommandError("sig too short")

    # Armor framing: exact start (no armor headers) and exact end.
    if not untrusted_armored_sig.startswith(sig_start):
        raise BadCommandError(
            "Invalid start of signature - check there are no armor headers"
        )
    if not untrusted_armored_sig.endswith(sig_end):
        raise BadCommandError("Invalid end of signature (trailing junk?)")

    # Tail layout: "\n=" marker, four base64 characters of CRC, then the
    # closing armor line.
    crc_marker = untrusted_armored_sig[-35:-33]
    if crc_marker != b"\n=":
        raise BadCommandError("Missing or misplaced CRC")
    untrusted_crc_text = untrusted_armored_sig[-33:-29]
    check_base64(untrusted_crc_text)
    untrusted_crc24 = a2b_base64(untrusted_crc_text)

    # Everything between the opening armor and the CRC marker is base64,
    # with line breaks removed.
    untrusted_wrapped_body = untrusted_armored_sig[len(sig_start) : -35]
    untrusted_base64_body = b"".join(untrusted_wrapped_body.split(b"\n"))
    sig_base64_len = len(untrusted_base64_body)
    if sig_base64_len < 88:
        raise BadCommandError("sig too short")
    if sig_base64_len % 4:
        raise BadCommandError("base64-encoded data has length not a multiple of 4")

    # With padding, the last data character may only be one whose unused
    # low bits are zero; the permitted characters depend on the pad length.
    if untrusted_base64_body.endswith(b"=="):
        pad_len, permitted_last = 2, b'AQgw'
    elif untrusted_base64_body.endswith(b"="):
        pad_len, permitted_last = 1, b'AEIMQUYcgkosw048'
    else:
        pad_len, permitted_last = 0, b''
    if pad_len:
        if untrusted_base64_body[-1 - pad_len] not in permitted_last:
            raise BadCommandError("bad base64 line: bad byte before padding")
        check_base64(untrusted_base64_body[:-pad_len])
    else:
        check_base64(untrusted_base64_body)
    untrusted_sig_to_write = a2b_base64(untrusted_base64_body)

    # The decoded data must be a single signature packet of the right kind.
    check_binary_signature(untrusted_sig_to_write, hash_str)

    # Emit the signature in canonical armored form.
    return reconstruct_armored_sig(untrusted_crc24, untrusted_sig_to_write)


def main(command_file, sig_file) -> None:
    """
    Read an inline-signed command from stdin, validate it, and write the
    command line to command_file and the canonicalized armored signature to
    sig_file.

    Raises BadCommandError when the input is rejected.
    """
    stream = sys.stdin.buffer
    untrusted_input = stream.read(8192)
    # Any byte beyond the limit means the input is oversized.
    if stream.read(1):
        raise BadCommandError("Command too long (limit is 8192 bytes)")
    if len(untrusted_input) < len(shortest_input):
        raise BadCommandError("input too short")
    if not untrusted_input.startswith(prefix):
        raise BadCommandError("not an inline signed message")
    if not untrusted_input.endswith(b"\n"):
        raise BadCommandError("No trailing newline (webhook bug?)")

    # Six-byte hash name right after the prefix, followed by a blank line.
    hash_str = untrusted_input[41:47]
    hash_terminator = untrusted_input[47:49]
    if (
        hash_str not in (b"SHA256", b"SHA384", b"SHA512")
        or hash_terminator != b"\n\n"
    ):
        raise BadCommandError("bad hash algorithm")

    # First line after the header is the command; the rest is the signature.
    untrusted_command, _newline, untrusted_sig = untrusted_input[49:].partition(
        b"\n"
    )
    check_command(untrusted_command)

    reconstructed_sig = canonicalize_sig(untrusted_sig, hash_str)
    with open(sig_file, "wb") as sig_out:
        with open(command_file, "wb") as command_out:
            sig_out.write(reconstructed_sig)
            sig_out.flush()
            command_out.write(untrusted_command)
            command_out.flush()


if __name__ == "__main__":
    # argv[1]: file to receive the command; argv[2]: file to receive the
    # canonicalized signature.
    command_path, sig_path = sys.argv[1], sys.argv[2]
    try:
        main(command_path, sig_path)
    except (BadCommandError, OSError) as error:
        # Report the rejection reason and signal failure to the caller.
        print(error, file=sys.stderr)
        raise SystemExit(1)
