rpki/scripts/cir/cir-rsync-wrapper

128 lines
3.7 KiB
Python
Executable File

#!/usr/bin/env python3
import os
import shutil
import sys
from pathlib import Path
from urllib.parse import urlparse
def real_rsync_bin() -> str:
    """Locate the genuine rsync executable this wrapper delegates to.

    Lookup order:
      1. the REAL_RSYNC_BIN environment variable, if set to a non-empty value;
      2. /usr/bin/rsync, if that path exists;
      3. the first ``rsync`` found on PATH via ``shutil.which``.

    Raises:
        SystemExit: when none of the above yields a path.
    """
    override = os.environ.get("REAL_RSYNC_BIN")
    if override:
        return override

    candidate = Path("/usr/bin/rsync")
    if candidate.exists():
        return str(candidate)

    located = shutil.which("rsync")
    if not located:
        raise SystemExit("cir-rsync-wrapper: REAL_RSYNC_BIN is not set and rsync was not found")
    return located
def rewrite_arg(arg: str, mirror_root: str | None) -> str:
if not arg.startswith("rsync://"):
return arg
if not mirror_root:
raise SystemExit(
"cir-rsync-wrapper: CIR_MIRROR_ROOT is required when an rsync:// source is present"
)
parsed = urlparse(arg)
if parsed.scheme != "rsync" or not parsed.hostname:
raise SystemExit(f"cir-rsync-wrapper: invalid rsync URI: {arg}")
path = parsed.path.lstrip("/")
local = Path(mirror_root).resolve() / parsed.hostname
if path:
local = local / path
local_str = str(local)
if local.exists() and local.is_dir() and not local_str.endswith("/"):
local_str += "/"
elif arg.endswith("/") and not local_str.endswith("/"):
local_str += "/"
return local_str
def filter_args(args: list[str]) -> list[str]:
    """Build the argument list that is handed on to the real rsync.

    Every argument is passed through ``rewrite_arg`` using the mirror root
    taken from the CIR_MIRROR_ROOT environment variable. When at least one
    argument starts with ``rsync://``, ``--address`` and ``--contimeout`` are
    also removed. This covers the ``--opt=value`` form and the
    ``--opt value`` form; in the second form the following argument is
    removed too.
    """
    mirror_root = os.environ.get("CIR_MIRROR_ROOT")
    has_rsync_uri = any(a.startswith("rsync://") for a in args)
    # NOTE(review): presumably dropped because they only apply to daemon
    # connections, which no longer happen once sources are rewritten locally.
    dropped = ("--address", "--contimeout")
    dropped_inline = tuple(opt + "=" for opt in dropped)

    kept: list[str] = []
    stream = iter(args)
    for token in stream:
        if has_rsync_uri and token in dropped:
            next(stream, None)  # discard the option's separate value, if any
            continue
        if has_rsync_uri and token.startswith(dropped_inline):
            continue
        kept.append(rewrite_arg(token, mirror_root))
    return kept
def local_link_mode_enabled() -> bool:
    """Report whether CIR_LOCAL_LINK_MODE is set to a truthy word.

    The value is compared case-insensitively against "1", "true", "yes" and
    "on". An unset or empty variable, or any other value, means disabled.
    """
    truthy = ("1", "true", "yes", "on")
    setting = os.environ.get("CIR_LOCAL_LINK_MODE") or ""
    return setting.lower() in truthy
def extract_source_and_dest(args: list[str]) -> tuple[str, str]:
    """Return the (source, destination) operands of an rsync command line.

    Arguments starting with ``-`` are treated as options and skipped. For
    options known to take a separate value (``--timeout 30``, ``-e ssh``,
    ``--link-dest DIR`` ...), the following argument is skipped as well, so
    that it is never mistaken for an operand. The ``--opt=value`` form needs
    no special handling because it is a single ``-``-prefixed argument.
    Everything after a literal ``--`` is an operand. The last two operands
    are returned.

    Raises:
        SystemExit: if fewer than two operands are found.
    """
    # rsync options with a required argument (long and short spellings).
    takes_value = {
        "--timeout", "--contimeout", "--min-size", "--max-size", "--max-delete",
        "--include", "--include-from", "--exclude", "--exclude-from",
        "--files-from", "--filter", "--compare-dest", "--copy-dest",
        "--link-dest", "--backup-dir", "--partial-dir", "--temp-dir",
        "--suffix", "--rsh", "--rsync-path", "--address", "--port",
        "--bwlimit", "--log-file", "--out-format", "--chmod", "--block-size",
        "--modify-window", "--password-file", "--remote-option",
        "--skip-compress", "--iconv", "--sockopts",
        "-e", "-f", "-T", "-B", "-M", "-@",
    }
    operands: list[str] = []
    stream = iter(args)
    for arg in stream:
        if arg == "--":
            # End of options: everything that follows is an operand.
            operands.extend(stream)
            break
        if arg in takes_value:
            next(stream, None)  # skip the option's separate value
            continue
        if arg.startswith("-"):
            continue
        operands.append(arg)
    if len(operands) < 2:
        raise SystemExit("cir-rsync-wrapper: expected source and destination arguments")
    return operands[-2], operands[-1]
def maybe_exec_local_link_sync(args: list[str], rewritten_any: bool) -> None:
    """Replace this process with the local link-sync helper when enabled.

    Returns without doing anything in two cases: when no ``rsync://``
    argument was present on the original command line (``rewritten_any`` is
    false), or when CIR_LOCAL_LINK_MODE is not truthy.

    Otherwise it extracts the already-rewritten source and destination from
    ``args``. It then ``execv``s ``cir-local-link-sync.py``, which sits next
    to this script, under the current interpreter. ``--delete`` is forwarded
    whenever the rsync arguments request deletion in any spelling.

    Raises:
        SystemExit: if the source is still an ``rsync://`` URI, or if the
            helper script does not exist.
    """
    if not rewritten_any or not local_link_mode_enabled():
        return
    source, dest = extract_source_and_dest(args)
    if source.startswith("rsync://"):
        raise SystemExit("cir-rsync-wrapper: expected rewritten local source for CIR_LOCAL_LINK_MODE")
    helper = Path(__file__).with_name("cir-local-link-sync.py")
    if not helper.is_file():
        raise SystemExit(f"cir-rsync-wrapper: local link helper not found: {helper}")
    # rsync's --del and --delete-{before,during,delay,after} all imply --delete.
    delete_options = {
        "--delete", "--del", "--delete-before", "--delete-during",
        "--delete-delay", "--delete-after",
    }
    cmd = [sys.executable, str(helper)]
    if any(arg in delete_options for arg in args):
        cmd.append("--delete")
    cmd.extend([source, dest])
    os.execv(sys.executable, cmd)
def main() -> int:
    """Rewrite ``rsync://`` operands onto the CIR mirror and exec rsync.

    The command-line arguments are filtered and rewritten by
    ``filter_args``. Then ``maybe_exec_local_link_sync`` gets a chance to
    take over (it replaces the process when link mode is on). Otherwise the
    process is replaced with the real rsync binary, run with the rewritten
    arguments.

    Returns:
        127 if executing the real rsync fails. On success ``os.execv`` never
        returns.
    """
    args = sys.argv[1:]
    rewritten_any = any(arg.startswith("rsync://") for arg in args)
    rewritten = filter_args(args)
    maybe_exec_local_link_sync(rewritten, rewritten_any)
    # Resolve once so argv[0] and the executed path are guaranteed to match.
    rsync_bin = real_rsync_bin()
    try:
        os.execv(rsync_bin, [rsync_bin, *rewritten])
    except OSError as exc:
        print(f"cir-rsync-wrapper: cannot execute {rsync_bin}: {exc}", file=sys.stderr)
    return 127


if __name__ == "__main__":
    raise SystemExit(main())