#!/usr/bin/env python3

# Copyright: 2026 Hector CAO <hector.cao@canonical.com>
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import argparse
import subprocess
from dataclasses import dataclass, field
from pathlib import Path


# Notice that prepend_description() inserts at the top of the extended
# Description of every binary package's DEBIAN/control. Each entry becomes one
# continuation line (a single leading space is added on insertion); the final
# "." entry is presumably the deb822 blank-line marker separating the notice
# from the package's own text.
# Keep this text stable: prepend_description() detects an already-inserted
# notice by comparing these exact lines, so editing them causes the new notice
# to be inserted again above the old one in previously processed files.
DESCRIPTION_PREFIX_LINES: list[str] = [
    "This package is part of the Ubuntu Virtualization",
    "Hardware Enablement variant, which provides the",
    "virtualization stack of future Ubuntu releases back to",
    "the Ubuntu LTS. The package names are suffixed with",
    "-hwe in comparison to the equivalent base packages.",
    "However, some documentation and manual pages may still",
    "refer to the base package names. In those cases, the",
    "base and -hwe package names should be considered",
    "interchangeable. The -hwe packages can fulfill",
    "dependencies to the base packages, no package is",
    "expected to directly depend on a -hwe suffixed package",
    "to keep them interchangeable.",
    ".",
]


@dataclass
class PackageControl:
    """Minimal parser for a single-paragraph binary ``DEBIAN/control`` file.

    Field names are stored lower-cased. Each field maps to the list of its
    non-empty value fragments: the text after ``Field:`` followed by every
    stripped continuation line.
    """

    package: str
    fields: dict[str, list[str]] = field(default_factory=dict)

    @classmethod
    def from_file(cls, control_path: Path) -> "PackageControl":
        """Parse *control_path* into a :class:`PackageControl`.

        Continuation lines (starting with a space or tab) are appended to the
        most recent field. A blank line or a line without ``:`` ends the
        current field.

        When no non-empty ``Package`` value is present, the package name is
        derived from the path. For ``debian/<pkg>/DEBIAN/control`` this is
        ``<pkg>``; otherwise it is the name of the file's parent directory.
        """
        parsed_fields: dict[str, list[str]] = {}
        current_field: str | None = None

        for line in control_path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                current_field = None
                continue

            if line[0] in " \t":
                if current_field is None:
                    continue
                continuation_value = line.strip()
                if continuation_value:
                    parsed_fields.setdefault(current_field, []).append(continuation_value)
                continue

            if ":" not in line:
                current_field = None
                continue

            key, value = line.split(":", 1)
            current_field = key.strip().lower()
            field_value = value.strip()
            parsed_fields.setdefault(current_field, [])
            if field_value:
                parsed_fields[current_field].append(field_value)

        package_values = parsed_fields.get("package")
        if package_values:
            package = package_values[0]
        else:
            parent = control_path.parent
            # For debian/<pkg>/DEBIAN/control the useful name is <pkg>,
            # not the literal "DEBIAN" directory.
            if parent.name == "DEBIAN" and parent.parent.name:
                package = parent.parent.name
            else:
                package = parent.name
        return cls(package=package, fields=parsed_fields)

    def relation_packages(self, field_name: str) -> set[str]:
        """Return the bare package names referenced by relationship *field_name*.

        The field lookup is case-insensitive. Entries are split on ``,`` and on
        the ``|`` alternative separator, so every alternative is reported.

        The following are dropped from each entry:
        - version constraints such as ``(>= 1.0)``, with or without a preceding space;
        - ``[...]`` architecture lists;
        - ``<...>`` build-profile lists;
        - ``:qualifier`` suffixes such as ``:any``.
        """
        entries: set[str] = set()

        for raw_values in self.fields.get(field_name.lower(), []):
            for relation in raw_values.split(","):
                for alternative in relation.split("|"):
                    package_name = self._bare_package_name(alternative)
                    if package_name:
                        entries.add(package_name)

        return entries

    @staticmethod
    def _bare_package_name(relation: str) -> str:
        """Return the package name of a single relationship atom, or ''."""
        name_part = relation.strip()
        # The name ends at the first qualifier opener; constraints may be
        # written without a separating space, e.g. "foo(>= 1)".
        for delimiter in ("(", "[", "<"):
            name_part = name_part.split(delimiter, 1)[0]
        tokens = name_part.split()
        if not tokens:
            return ""
        return tokens[0].split(":", 1)[0]


def collect_binary_packages() -> list[str]:
    """Return the binary package names printed by ``dh_listpackages``.

    The tool runs in the current working directory with its output captured.
    Blank lines are dropped and surrounding whitespace is stripped from each
    name.

    Raises:
        FileNotFoundError: if ``dh_listpackages`` cannot be executed.
        subprocess.CalledProcessError: if it exits with a non-zero status.
    """
    completed = subprocess.run(
        ["dh_listpackages"], check=True, text=True, capture_output=True
    )
    names: list[str] = []
    for raw_line in completed.stdout.splitlines():
        name = raw_line.strip()
        if name:
            names.append(name)
    return names


def run_doc_symlinks() -> int:
    """Link each base package's doc entries into the matching -hwe doc dir.

    For every ``<name>-hwe`` package reported by dh_listpackages, each entry
    of ``debian/<name>/usr/share/doc/<name>`` gets a relative symlink at
    ``debian/<name>-hwe/usr/share/doc/<name>-hwe/<entry>`` pointing to
    ``../<name>/<entry>``.

    The following are skipped with a message and left untouched:
    - packages without the ``-hwe`` suffix;
    - packages whose base or -hwe doc directory is missing;
    - destinations that already exist, including dangling symlinks.

    Returns 0 on completion, or 2 if dh_listpackages is missing or fails.
    """
    try:
        packages = collect_binary_packages()
    except FileNotFoundError:
        print("[error] dh_listpackages not found in PATH")
        return 2
    except subprocess.CalledProcessError as error:
        if error.stderr:
            print(error.stderr.rstrip())
        print(f"[error] dh_listpackages failed with exit code {error.returncode}")
        return 2

    staging_root = Path("debian")
    for hwe_name in packages:
        if not hwe_name.endswith("-hwe"):
            print(f"skipped package: {hwe_name} (not an hwe package)")
            continue

        base_name = hwe_name.removesuffix("-hwe")
        base_docs = staging_root / base_name / "usr/share/doc" / base_name
        hwe_docs = staging_root / hwe_name / "usr/share/doc" / hwe_name

        if not base_docs.is_dir():
            print(f"skipped package: {hwe_name} (missing base doc dir: {base_docs})")
            continue
        if not hwe_docs.is_dir():
            print(f"skipped package: {hwe_name} (missing hwe doc dir: {hwe_docs})")
            continue

        for entry in base_docs.iterdir():
            link_path = hwe_docs / entry.name
            # exists() follows symlinks and is False for a dangling link, so
            # is_symlink() is also checked; any existing destination is kept.
            if link_path.exists() or link_path.is_symlink():
                print(f"skipped entry: {link_path} (already exists)")
                continue

            target = f"../{base_name}/{entry.name}"
            link_path.symlink_to(target)
            print(f"created symlink: {link_path} -> {target}")

    return 0

# All binary package names must end with the -hwe suffix.
def run_check_names(controls: list[PackageControl]) -> int:
    """Verify that every parsed package name ends with ``-hwe``.

    Any offending names are printed under an error header and 1 is returned.
    Otherwise a one-line summary is printed and 0 is returned.
    """
    offenders: list[str] = []
    for control in controls:
        if control.package.endswith("-hwe"):
            continue
        offenders.append(control.package)

    if not offenders:
        print(f"[ok] All {len(controls)} package(s) end with '-hwe'")
        return 0

    print("[error] Packages missing '-hwe' suffix:")
    for name in offenders:
        print(f"  - {name}")
    return 1

# No package may declare a relationship on a -hwe package, except on ubuntu-virt-hwe.
def run_check_only_deps_hwe(controls: list[PackageControl]) -> int:
    """Reject relationships that target -hwe packages.

    For each package, the Depends, Pre-Depends, Recommends, Suggests and
    Enhances fields are inspected. Every referenced name ending in ``-hwe``
    other than ``ubuntu-virt-hwe`` is reported, one error line per package
    and field, with the names sorted.

    Returns 1 if anything was reported, else 0. Nothing is printed on success.
    """
    allowed_hwe = "ubuntu-virt-hwe"
    checked_fields = ("Depends", "Pre-Depends", "Recommends", "Suggests", "Enhances")
    status = 0
    for control in controls:
        for relation_field in checked_fields:
            offending = sorted(
                name
                for name in control.relation_packages(relation_field)
                if name != allowed_hwe and name.endswith("-hwe")
            )
            if not offending:
                continue
            print(
                f"[error] {control.package} {relation_field} on -hwe package(s): {', '.join(offending)}"
            )
            status = 1
    return status

# Every -hwe binary package (other than ubuntu-virt-hwe itself) must Replace
# and Provide its base package (the same name without the -hwe suffix), and
# must Depend on ubuntu-virt-hwe.
def run_check_deps_base(controls: list[PackageControl]) -> int:
    """Check that each -hwe package can stand in for its base package.

    This applies to every package ending in ``-hwe`` except
    ``ubuntu-virt-hwe`` itself. With the base name being the package name
    minus the ``-hwe`` suffix, each such package must declare:
    - ``Replaces`` on the base name;
    - ``Provides`` on the base name;
    - ``Depends`` on ``ubuntu-virt-hwe``.

    Packages without the suffix are reported as skipped. For each failing
    package, one error line lists exactly which relationships are missing.

    Returns 1 if any package is missing a required relationship, else 0.
    """
    failed = False
    for control in controls:
        if control.package == "ubuntu-virt-hwe":
            continue
        if not control.package.endswith("-hwe"):
            print(f"skipped package: {control.package} (not an hwe package)")
            continue

        base_name = control.package.removesuffix("-hwe")
        required = (
            ("Replaces", base_name),
            ("Provides", base_name),
            ("Depends", "ubuntu-virt-hwe"),
        )
        # Report each missing relationship by name; the previous message only
        # mentioned Replaces/Provides even when Depends was the culprit.
        missing = [
            f"{field_name} on {target}"
            for field_name, target in required
            if target not in control.relation_packages(field_name)
        ]
        if missing:
            print(f"[error] {control.package} missing {', '.join(missing)}")
            failed = True

    return 1 if failed else 0

def run_check() -> int:
    """Run every packaging check against the generated binary control files.

    Binary packages come from dh_listpackages. Each existing
    ``debian/<pkg>/DEBIAN/control`` is parsed; missing ones are reported and
    skipped. The name, base-relationship and -hwe-relationship checks then
    all run, even if an earlier one fails, so every problem is reported.

    Returns 0 when all checks pass. Returns 1 otherwise, including when
    dh_listpackages is missing, fails, or lists no packages.
    """
    try:
        packages = collect_binary_packages()
    except FileNotFoundError:
        print("[error] dh_listpackages not found in PATH")
        return 1
    except subprocess.CalledProcessError as error:
        if error.stderr:
            print(error.stderr.rstrip())
        print(f"[error] dh_listpackages failed with exit code {error.returncode}")
        return 1

    if not packages:
        print("[error] dh_listpackages returned no binary packages")
        return 1

    controls: list[PackageControl] = []
    for name in packages:
        control_file = Path("debian") / name / "DEBIAN/control"
        # The package might not be built for the current architecture; skip it.
        if control_file.is_file():
            controls.append(PackageControl.from_file(control_file))
        else:
            print(f"skipped package: {name} (missing control file: {control_file})")

    statuses = (
        run_check_names(controls),
        run_check_deps_base(controls),
        run_check_only_deps_hwe(controls),
    )
    return 1 if any(statuses) else 0

def prepend_description(control_path: Path) -> bool:
    """Insert DESCRIPTION_PREFIX_LINES right after the first Description line.

    Each prefix line is written as a continuation line (one leading space)
    directly below the first line starting with ``Description:`` in
    *control_path*. If exactly that block already follows the line, the file
    is not rewritten. When the file is rewritten, a missing trailing newline
    on the ``Description:`` line is added.

    Returns True if a ``Description:`` line was found, whether or not the
    file changed. Returns False otherwise.
    """
    content = control_path.read_text(encoding="utf-8").splitlines(keepends=True)
    prefix_block = [" " + text + "\n" for text in DESCRIPTION_PREFIX_LINES]

    position = next(
        (idx for idx, text in enumerate(content) if text.startswith("Description:")),
        None,
    )
    if position is None:
        return False

    if not content[position].endswith("\n"):
        content[position] = content[position] + "\n"

    start = position + 1
    # Already prefixed on a previous run: leave the file untouched.
    if content[start : start + len(prefix_block)] == prefix_block:
        return True

    content[start:start] = prefix_block
    control_path.write_text("".join(content), encoding="utf-8")
    return True


def run_add_description() -> int:
    """Prepend the HWE notice to the Description of every listed package.

    For each package from dh_listpackages whose ``debian/<pkg>/DEBIAN/control``
    exists, prepend_description() is applied. Packages without a control file
    are reported and skipped.

    Returns one of:
    - 0 if every processed file had a Description field;
    - 1 if any file lacked one;
    - 2 if dh_listpackages is missing or fails.
    """
    try:
        packages = collect_binary_packages()
    except FileNotFoundError:
        print("[error] dh_listpackages not found in PATH")
        return 2
    except subprocess.CalledProcessError as error:
        if error.stderr:
            print(error.stderr.rstrip())
        print(f"[error] dh_listpackages failed with exit code {error.returncode}")
        return 2

    exit_code = 0
    for name in packages:
        control_file = Path("debian") / name / "DEBIAN/control"

        # The package might not be built for the current architecture; skip it.
        if not control_file.is_file():
            print(f"skipped package: {name} (missing control file: {control_file})")
        elif prepend_description(control_file):
            print(f"updated description: {control_file}")
        else:
            print(f"[error] Description field not found in: {control_file}")
            exit_code = 1

    return exit_code


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the HWE helper.

    Args:
        argv: Arguments to parse, excluding the program name. When None,
            argparse reads ``sys.argv[1:]``, which matches the previous
            no-argument behavior.

    Returns:
        Namespace whose ``action`` attribute names the chosen subcommand.
        A subcommand is required; argparse exits with an error otherwise.
    """
    parser = argparse.ArgumentParser(
        description="HWE helper utilities for Debian packaging."
    )
    subparsers = parser.add_subparsers(dest="action", required=True)

    subparsers.add_parser(
        "doc-symlinks",
        help="Create missing HWE doc symlinks using dh_listpackages output.",
    )
    # The help text describes everything `check` verifies, not only the
    # name-suffix check that the old text mentioned.
    subparsers.add_parser(
        "check",
        help=(
            "Check binary packages from dh_listpackages: names end with '-hwe'; "
            "each -hwe package Replaces/Provides its base package and Depends on "
            "ubuntu-virt-hwe; no relationship targets a -hwe package other than "
            "ubuntu-virt-hwe."
        ),
    )
    subparsers.add_parser(
        "add-description",
        help="Prepend the HWE notice to the Description in each binary package's DEBIAN/control.",
    )

    return parser.parse_args(argv)


def main() -> int:
    """Run the handler for the subcommand given on the command line.

    Returns the handler's exit status, or 2 for an unrecognized action.
    """
    args = parse_args()
    handlers = {
        "doc-symlinks": run_doc_symlinks,
        "check": run_check,
        "add-description": run_add_description,
    }
    handler = handlers.get(args.action)
    if handler is None:
        print(f"[error] Unsupported action: {args.action}")
        return 2
    return handler()


if __name__ == "__main__":
    # Propagate the selected action's status as the process exit code.
    exit_status = main()
    raise SystemExit(exit_status)
