#!/usr/bin/env python3
import os
import shutil
import sys
from pathlib import Path
from urllib.parse import urlparse


def real_rsync_bin() -> str:
    """Return the path of the genuine rsync binary to delegate to.

    Lookup order: a non-empty ``$REAL_RSYNC_BIN`` (returned as-is), then
    ``/usr/bin/rsync`` if it exists, then the first ``rsync`` found on
    ``$PATH``. Candidates from the last two steps that resolve to this very
    script are skipped, so a wrapper installed as ``rsync`` on ``$PATH``
    cannot end up exec'ing itself in an endless loop.

    Raises:
        SystemExit: if no suitable rsync binary can be found.
    """
    env = os.environ.get("REAL_RSYNC_BIN")
    if env:
        return env

    self_path = Path(__file__).resolve()

    def is_self(candidate: str) -> bool:
        # Compare fully resolved paths so symlinks pointing at this wrapper
        # are recognised too.
        try:
            return Path(candidate).resolve() == self_path
        except (OSError, RuntimeError):
            return False

    default = "/usr/bin/rsync"
    if Path(default).exists() and not is_self(default):
        return default

    # Walk $PATH one directory at a time: shutil.which() alone only reports
    # the first hit, which may be this wrapper.
    search_path = os.environ.get("PATH", os.defpath)
    for directory in search_path.split(os.pathsep):
        if not directory:
            continue
        found = shutil.which("rsync", path=directory)
        if found and not is_self(found):
            return found
    raise SystemExit("cir-rsync-wrapper: REAL_RSYNC_BIN is not set and rsync was not found")


def rewrite_arg(arg: str, mirror_root: str | None) -> str:
    if not arg.startswith("rsync://"):
        return arg
    if not mirror_root:
        raise SystemExit(
            "cir-rsync-wrapper: CIR_MIRROR_ROOT is required when an rsync:// source is present"
        )
    parsed = urlparse(arg)
    if parsed.scheme != "rsync" or not parsed.hostname:
        raise SystemExit(f"cir-rsync-wrapper: invalid rsync URI: {arg}")
    path = parsed.path.lstrip("/")
    local = Path(mirror_root).resolve() / parsed.hostname
    if path:
        local = local / path
    local_str = str(local)
    if local.exists() and local.is_dir() and not local_str.endswith("/"):
        local_str += "/"
    elif arg.endswith("/") and not local_str.endswith("/"):
        local_str += "/"
    return local_str


def filter_args(args: list[str]) -> list[str]:
    """Build the argument list handed to the real rsync (or the link helper).

    Every kept argument goes through ``rewrite_arg`` with ``$CIR_MIRROR_ROOT``.
    When at least one argument is an ``rsync://`` URI, ``--address`` and
    ``--contimeout`` are removed first. Both the ``--opt=value`` form and the
    bare form are dropped; the bare form also drops the argument after it.
    Per the rsync manual, these options apply to daemon connections.
    """
    mirror_root = os.environ.get("CIR_MIRROR_ROOT")
    strip_daemon_opts = any(a.startswith("rsync://") for a in args)
    daemon_opts = ("--address", "--contimeout")
    daemon_opt_prefixes = tuple(f"{opt}=" for opt in daemon_opts)

    kept: list[str] = []
    drop_next = False
    for arg in args:
        if drop_next:
            # This was the value of a bare daemon-only option; discard it.
            drop_next = False
            continue
        if strip_daemon_opts and arg in daemon_opts:
            drop_next = True
            continue
        if strip_daemon_opts and arg.startswith(daemon_opt_prefixes):
            continue
        kept.append(rewrite_arg(arg, mirror_root))
    return kept


def local_link_mode_enabled() -> bool:
    """Report whether ``$CIR_LOCAL_LINK_MODE`` is set to a truthy token.

    Truthy tokens (case-insensitive): ``1``, ``true``, ``yes``, ``on``.
    An unset variable or any other value means the mode is disabled.
    """
    truthy_tokens = ("1", "true", "yes", "on")
    raw = os.environ.get("CIR_LOCAL_LINK_MODE", "")
    return raw.lower() in truthy_tokens


def extract_source_and_dest(args: list[str]) -> tuple[str, str]:
    """Return the last two positional (non-option) arguments as (source, dest).

    Options are skipped:
    - A bare option listed in ``expects_value`` also consumes the following
      argument.
    - ``--opt=value`` forms and other flags are skipped on their own.
    - Everything after a ``--`` terminator is treated as positional.

    Raises:
        SystemExit: if fewer than two positional arguments remain.
    """
    # rsync options that take their value as a separate argument when written
    # without "=" (e.g. "--log-file /tmp/x"). That value must not be mistaken
    # for a source or destination path.
    expects_value = {
        "--timeout",
        "--min-size",
        "--max-size",
        "--include",
        "--exclude",
        "--include-from",
        "--exclude-from",
        "--files-from",
        "--filter",
        "-f",
        "--compare-dest",
        "--copy-dest",
        "--link-dest",
        "--backup-dir",
        "--suffix",
        "--partial-dir",
        "--temp-dir",
        "-T",
        "--rsh",
        "-e",
        "--log-file",
        "--out-format",
        "--chmod",
        "--bwlimit",
        "--block-size",
        "-B",
        "--max-delete",
        "--modify-window",
        "--password-file",
        "--port",
        "--address",
        "--contimeout",
    }
    positionals: list[str] = []
    i = 0
    while i < len(args):
        arg = args[i]
        if arg == "--":
            # End of options: everything left is a path, even if it starts with "-".
            positionals.extend(args[i + 1:])
            break
        if arg in expects_value:
            i += 2
            continue
        if arg.startswith("-"):
            # Covers "--opt=value" as well as value-less flags like "-rt".
            i += 1
            continue
        positionals.append(arg)
        i += 1
    if len(positionals) < 2:
        raise SystemExit("cir-rsync-wrapper: expected source and destination arguments")
    return positionals[-2], positionals[-1]


def maybe_exec_local_link_sync(args: list[str], rewritten_any: bool) -> None:
    """Replace this process with ``cir-local-link-sync.py`` when applicable.

    Does nothing (returns ``None``) unless *rewritten_any* is true and
    ``CIR_LOCAL_LINK_MODE`` is enabled. Otherwise it takes the last two
    positional arguments of *args* as source and destination. It then execs
    the helper that sits next to this script, passing ``--delete`` if any
    deletion option was requested.

    Raises:
        SystemExit: if the source is still an ``rsync://`` URI, or if the
            source/destination cannot be determined.
    """
    if not rewritten_any or not local_link_mode_enabled():
        return
    source, dest = extract_source_and_dest(args)
    if source.startswith("rsync://"):
        raise SystemExit("cir-rsync-wrapper: expected rewritten local source for CIR_LOCAL_LINK_MODE")
    helper = Path(__file__).with_name("cir-local-link-sync.py")
    cmd = [sys.executable, str(helper)]
    # In rsync, each --delete-<timing> variant (and the --del alias) implies
    # --delete, so any of them must request deletion from the helper.
    delete_flags = {
        "--delete",
        "--del",
        "--delete-before",
        "--delete-during",
        "--delete-delay",
        "--delete-after",
    }
    if not delete_flags.isdisjoint(args):
        cmd.append("--delete")
    cmd.extend([source, dest])
    os.execv(sys.executable, cmd)


def main() -> int:
    """Rewrite rsync:// arguments, then exec the link helper or the real rsync.

    On success the process image is replaced and this function never returns.

    Returns:
        127 if the real rsync binary could not be executed.
    """
    args = sys.argv[1:]
    rewritten_any = any(arg.startswith("rsync://") for arg in args)
    rewritten = filter_args(args)
    maybe_exec_local_link_sync(rewritten, rewritten_any)
    # Resolve once so the exec path and argv[0] are guaranteed to agree.
    rsync_bin = real_rsync_bin()
    try:
        os.execv(rsync_bin, [rsync_bin, *rewritten])
    except OSError as exc:
        # execv only comes back by raising; report it instead of a traceback.
        print(f"cir-rsync-wrapper: cannot execute {rsync_bin}: {exc}", file=sys.stderr)
    return 127


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
