#!/usr/bin/python3
#
# The Qubes OS Project, https://www.qubes-os.org/
#
# Copyright (C) 2023 Marek Marczykowski-Górecki
#                           <marmarek@invisiblethingslab.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with this library; if not, see <https://www.gnu.org/licenses/>.

import argparse
import hashlib
from typing import List, Dict, Type

from debian.deb822 import Deb822, Changes, Dsc

try:
    from debian.deb822 import BuildInfo
except ImportError:
    # older 'debian' package doesn't have BuildInfo class,
    # but it's easy enough to add
    # pylint: disable=protected-access
    from debian.deb822 import _gpg_multivalued

    class BuildInfo(_gpg_multivalued):
        """
        Fallback used when debian.deb822 does not provide BuildInfo.

        Declares the checksum sections as multivalued fields, each entry
        being split into the digest, the size and the file name.
        """

        _multivalued_fields: Dict[str, List[str]] = {
            f"checksums-{algo}": [algo, "size", "name"]
            for algo in ("md5", "sha1", "sha256", "sha512")
        }


from pathlib import Path


def get_checksums(file_path: Path):
    """
    Compute the size and digests of a file, as needed for .changes file

    The file is read only once, in chunks; every chunk is fed to all the
    hash objects, instead of re-reading the file for each digest.

    Returns a dict of strings with the keys:
     - "size": file size in bytes (as reported by stat), in decimal
     - "md5sum" and "md5": MD5 hex digest, stored under both names
     - "sha1", "sha256", "sha512": respective hex digests
    """
    checksums = {
        "size": str(file_path.stat().st_size),
    }
    # result key -> hash object;
    # python-debian uses 'md5sum' for 'md5' digest in hashlib
    hashers = {
        "md5sum": hashlib.new("md5"),
        "sha1": hashlib.new("sha1"),
        "sha256": hashlib.new("sha256"),
        "sha512": hashlib.new("sha512"),
    }
    with file_path.open("rb") as file_obj:
        # TODO: consider hashlib.file_digest in py3.11
        while buf := file_obj.read(65536):
            for hash_obj in hashers.values():
                hash_obj.update(buf)
    for digest, hash_obj in hashers.items():
        checksums[digest] = hash_obj.hexdigest()
        if digest == "md5sum":
            # duplicate hash as 'md5' too, for BuildInfo naming
            checksums["md5"] = checksums[digest]
    return checksums


def parse_dsc(dsc_path: Path):
    """
    Collect size and digests of the files listed in a .dsc file.

    Returns a dict indexed by file name. Each value is a dict with "size"
    and the digests found in the dsc, under the keys "md5sum" (duplicated
    as "md5"), "sha1", "sha256" and "sha512". A checksum section missing
    from the dsc is skipped, so its key is then absent from the result.
    """

    files_hashes = {}
    # Binary mode, same as patch_checksums(): parsing then does not depend
    # on the locale's default text encoding.
    with dsc_path.open("rb") as dsc_f:
        dsc = Dsc(dsc_f)
        for digest, section in (
            ("md5sum", "Files"),
            ("sha1", "Checksums-Sha1"),
            ("sha256", "Checksums-Sha256"),
            ("sha512", "Checksums-Sha512"),
        ):
            if section not in dsc:
                continue
            for f_line in dsc[section]:
                entry = files_hashes.setdefault(f_line["name"], {})
                entry[digest] = f_line[digest]
                # Each section carries the size; take the first one seen
                # ('Files' when present), so a dsc lacking 'Files' still
                # yields a size.
                entry.setdefault("size", f_line["size"])
                if digest == "md5sum":
                    # duplicate hash as 'md5' too, for BuildInfo naming
                    entry["md5"] = f_line[digest]

    return files_hashes


def patch_checksums(
    changes_path: Path, klass: Type[Deb822], files_hashes: dict[str, dict[str, str]]
):
    """
    Rewrite digests and sizes in a 'changes' or 'buildinfo' file, in place.

    The file is parsed with 'klass'. In every checksum section present,
    each entry whose name is a key of 'files_hashes' gets its digest and
    size replaced with the values stored there; other entries are left
    untouched. The result is then dumped back to the same path.

    Raises KeyError if 'files_hashes' has an entry for a listed file, but
    that entry lacks the digest of the section being processed.
    """
    # (key of the digest in an entry, name of the section holding it)
    digest_sections = (
        ("md5sum", "Files"),  # changes
        ("md5", "Checksums-Md5"),  # buildinfo
        ("sha1", "Checksums-Sha1"),
        ("sha256", "Checksums-Sha256"),
        ("sha512", "Checksums-Sha512"),
    )

    with changes_path.open("rb") as in_file:
        control = klass(in_file)
        for hash_key, section_name in digest_sections:
            if section_name in control:
                for entry in control[section_name]:
                    if entry["name"] in files_hashes:
                        new_values = files_hashes[entry["name"]]
                        entry[hash_key] = new_values[hash_key]
                        entry["size"] = new_values["size"]

    with changes_path.open("wb") as out_file:
        control.dump(out_file)


def main():
    """
    Command line entry point.

    Takes paths to the original .dsc, the .buildinfo and the .changes
    files, then updates checksums in the latter two (in that order).
    """
    cli = argparse.ArgumentParser()
    for positional in ("orig_dsc_path", "buildinfo_path", "changes_path"):
        cli.add_argument(positional, type=Path)
    opts = cli.parse_args()

    dsc_path = opts.orig_dsc_path
    buildinfo_path = opts.buildinfo_path

    # hashes of the files listed in the dsc, plus those of the dsc itself
    known_hashes = parse_dsc(dsc_path)
    known_hashes[dsc_path.name] = get_checksums(dsc_path)

    # buildinfo needs to be patched first: its hashes are taken only after
    # that, and then used when patching the changes file
    patch_checksums(buildinfo_path, BuildInfo, known_hashes)
    known_hashes[buildinfo_path.name] = get_checksums(buildinfo_path)
    patch_checksums(opts.changes_path, Changes, known_hashes)


# Call main() only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
