"""Helpers for publishing and consuming a small JSON "head record" that
describes a cluster head node (IP, GCS port, dashboard port, job server URL)
together with a freshness TTL, so readers can ignore stale records."""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _parse_utc(ts: str) -> datetime:
    """Parse an ISO 8601 timestamp, accepting a trailing "Z", and return it in UTC."""
    s = ts.strip()
    if s.endswith("Z"):
        # datetime.fromisoformat() only accepts "Z" from Python 3.11 onward.
        s = s[:-1] + "+00:00"
    return datetime.fromisoformat(s).astimezone(timezone.utc)


def _iso_z(dt: datetime) -> str:
    """Format a datetime as UTC ISO 8601 with a "Z" suffix, truncated to whole seconds."""
    return dt.astimezone(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
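# The two helpers above round-trip the "Z"-suffixed form used in records,
# e.g. (illustrative value): _iso_z(_parse_utc("2024-01-01T00:00:00Z")) == "2024-01-01T00:00:00Z".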
@dataclass(frozen=True)
class HeadRecord:
    """Immutable description of a cluster head node, as stored on disk."""

    cluster_name: str
    head_ip: str
    gcs_port: int
    dashboard_port: int
    job_server_url: str
    updated_at: str
    expires_at: str

    @staticmethod
    def from_dict(d: dict[str, Any]) -> "HeadRecord":
        """Build a record from a parsed JSON object, falling back to default ports."""
        return HeadRecord(
            cluster_name=str(d["cluster_name"]),
            head_ip=str(d["head_ip"]),
            gcs_port=int(d.get("gcs_port", 6379)),
            dashboard_port=int(d.get("dashboard_port", 8265)),
            job_server_url=str(
                d.get("job_server_url")
                or f"http://{d['head_ip']}:{int(d.get('dashboard_port', 8265))}"
            ),
            updated_at=str(d["updated_at"]),
            expires_at=str(d["expires_at"]),
        )

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict mirroring the dataclass fields."""
        return {
            "cluster_name": self.cluster_name,
            "head_ip": self.head_ip,
            "gcs_port": self.gcs_port,
            "dashboard_port": self.dashboard_port,
            "job_server_url": self.job_server_url,
            "updated_at": self.updated_at,
            "expires_at": self.expires_at,
        }

    def head_addr(self) -> str:
        """Return the "<head_ip>:<gcs_port>" address string."""
        return f"{self.head_ip}:{self.gcs_port}"
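# Illustrative on-disk shape produced by to_dict() and written with sort_keys=True
# (field values below are examples only):
#
#   {
#     "cluster_name": "demo",
#     "dashboard_port": 8265,
#     "expires_at": "2024-01-01T00:01:00Z",
#     "gcs_port": 6379,
#     "head_ip": "10.0.0.5",
#     "job_server_url": "http://10.0.0.5:8265",
#     "updated_at": "2024-01-01T00:00:00Z"
#   }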
def build_head_record(
    *,
    cluster_name: str,
    head_ip: str,
    gcs_port: int = 6379,
    dashboard_port: int = 8265,
    ttl_s: int = 60,
    now: datetime | None = None,
) -> HeadRecord:
    """Create a fresh record whose expires_at lies ttl_s seconds after now."""
    now_dt = now or _utc_now()
    expires = now_dt + timedelta(seconds=int(ttl_s))
    updated_at = _iso_z(now_dt)
    expires_at = _iso_z(expires)
    return HeadRecord(
        cluster_name=cluster_name,
        head_ip=head_ip,
        gcs_port=int(gcs_port),
        dashboard_port=int(dashboard_port),
        job_server_url=f"http://{head_ip}:{int(dashboard_port)}",
        updated_at=updated_at,
        expires_at=expires_at,
    )
def load_head_record(path: str, *, now: datetime | None = None) -> HeadRecord | None:
    """Load a record from path, returning None if it is missing, malformed, or expired."""
    p = Path(path)
    if not p.exists():
        return None
    try:
        obj = json.loads(p.read_text(encoding="utf-8"))
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None
    try:
        rec = HeadRecord.from_dict(obj)
        expires = _parse_utc(rec.expires_at)
    except Exception:
        return None
    now_dt = now or _utc_now()
    if expires <= now_dt:
        # Treat records past their TTL as absent rather than returning stale data.
        return None
    return rec
def write_head_record_atomic(path: str, rec: HeadRecord) -> None:
    """Write the record as JSON via a temporary file and an atomic rename."""
    p = Path(path)
    os.makedirs(p.parent, exist_ok=True)
    tmp = p.with_suffix(p.suffix + ".tmp")
    tmp.write_text(json.dumps(rec.to_dict(), indent=2, sort_keys=True) + "\n", encoding="utf-8")
    # Path.replace() is an atomic rename on POSIX filesystems, so readers never
    # observe a partially written file.
    tmp.replace(p)
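

# Minimal usage sketch (not part of the library surface): publish a record for a
# hypothetical head node, then read it back while it is still within its TTL.
# The path, cluster name, and IP below are illustrative values only.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        record_path = os.path.join(d, "head.json")  # hypothetical location
        rec = build_head_record(
            cluster_name="demo",
            head_ip="10.0.0.5",  # example value
            ttl_s=60,
        )
        write_head_record_atomic(record_path, rec)
        loaded = load_head_record(record_path)
        # Still within the 60-second TTL, so the record is returned.
        assert loaded is not None and loaded.head_addr() == "10.0.0.5:6379"
        print(loaded.job_server_url)  # -> http://10.0.0.5:8265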