from __future__ import annotations import json import sqlite3 import threading from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple from .models import serialize_node_row, serialize_node_summary from .util import ensure_parent, to_iso, utcnow class Storage: def __init__(self, db_path: str, node_id_prefix: str) -> None: self._db_path = db_path self._node_id_prefix = node_id_prefix ensure_parent(db_path) self._lock = threading.Lock() self._conn = sqlite3.connect(db_path, detect_types=sqlite3.PARSE_DECLTYPES, check_same_thread=False) self._conn.row_factory = sqlite3.Row with self._lock: self._conn.execute("PRAGMA foreign_keys = ON;") self._ensure_schema() # ------------------------------------------------------------------ # schema & helpers # ------------------------------------------------------------------ def _ensure_schema(self) -> None: """初始化表结构,确保服务启动时数据库结构就绪。""" with self._lock: self._conn.executescript( """ CREATE TABLE IF NOT EXISTS nodes ( id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, type TEXT NOT NULL, version TEXT, status TEXT NOT NULL, config_json TEXT, labels_json TEXT, meta_json TEXT, health_json TEXT, register_time TEXT, last_report TEXT, agent_last_report TEXT, last_updated TEXT ); CREATE TABLE IF NOT EXISTS kv ( key TEXT PRIMARY KEY, value TEXT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_nodes_status ON nodes(status); CREATE INDEX IF NOT EXISTS idx_nodes_name ON nodes(name); """ ) self._conn.commit() def close(self) -> None: with self._lock: self._conn.close() # ------------------------------------------------------------------ # Node ID allocation # ------------------------------------------------------------------ def allocate_node_id(self) -> str: """在 kv 表里维护自增序列,为新节点生成形如 A1 的 ID。""" with self._lock: cur = self._conn.execute("SELECT value FROM kv WHERE key = ?", ("node_id_seq",)) row = cur.fetchone() if row is None: next_id = 1 self._conn.execute("INSERT INTO kv(key, value) VALUES(?, ?)", ("node_id_seq", str(next_id))) else: next_id = int(row["value"]) + 1 self._conn.execute("UPDATE kv SET value = ? WHERE key = ?", (str(next_id), "node_id_seq")) self._conn.commit() return f"{self._node_id_prefix}{next_id}" # ------------------------------------------------------------------ # Query helpers # ------------------------------------------------------------------ def list_nodes(self) -> List[Dict[str, Any]]: with self._lock: cur = self._conn.execute( "SELECT id, name, status, type, version FROM nodes ORDER BY id ASC" ) rows = cur.fetchall() return [serialize_node_summary(row) for row in rows] def get_node(self, node_id: str) -> Optional[Dict[str, Any]]: with self._lock: cur = self._conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)) row = cur.fetchone() if row is None: return None return serialize_node_row(row) def get_node_raw(self, node_id: str) -> Optional[sqlite3.Row]: with self._lock: cur = self._conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)) row = cur.fetchone() return row def get_node_by_name(self, name: str) -> Optional[Dict[str, Any]]: with self._lock: cur = self._conn.execute("SELECT * FROM nodes WHERE name = ?", (name,)) row = cur.fetchone() if row is None: return None return serialize_node_row(row) # ------------------------------------------------------------------ # Mutation helpers # ------------------------------------------------------------------ def create_node( self, node_id: str, name: str, node_type: str, version: str | None, meta_data: Mapping[str, Any], status: str, register_time_iso: str, last_updated_iso: str, ) -> Dict[str, Any]: """插入节点初始记录,默认 config/label/health 为空。""" now_iso = last_updated_iso with self._lock: self._conn.execute( """ INSERT INTO nodes ( id, name, type, version, status, config_json, labels_json, meta_json, health_json, register_time, last_report, agent_last_report, last_updated ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( node_id, name, node_type, version, status, json.dumps({}), json.dumps([]), json.dumps(dict(meta_data)), json.dumps({}), register_time_iso, None, None, now_iso, ), ) self._conn.commit() created = self.get_node(node_id) if created is None: raise RuntimeError("Failed to read back created node") return created def update_node_meta( self, node_id: str, *, name: Optional[str] = None, node_type: Optional[str] = None, version: Optional[str | None] = None, meta_data: Optional[Mapping[str, Any]] = None, last_updated_iso: Optional[str] = None, ) -> Dict[str, Any]: """重注册时更新节点静态信息,缺省字段保持不变。""" updates: List[str] = [] params: List[Any] = [] if name is not None: updates.append("name = ?") params.append(name) if node_type is not None: updates.append("type = ?") params.append(node_type) if version is not None: updates.append("version = ?") params.append(version) if meta_data is not None: updates.append("meta_json = ?") params.append(json.dumps(dict(meta_data))) if last_updated_iso is not None: updates.append("last_updated = ?") params.append(last_updated_iso) if not updates: result = self.get_node(node_id) if result is None: raise KeyError(node_id) return result params.append(node_id) with self._lock: self._conn.execute( f"UPDATE nodes SET {', '.join(updates)} WHERE id = ?", tuple(params), ) self._conn.commit() updated = self.get_node(node_id) if updated is None: raise KeyError(node_id) return updated def update_config_and_labels( self, node_id: str, *, config: Optional[Mapping[str, Any]] = None, labels: Optional[Iterable[str]] = None ) -> Dict[str, Any]: """部分更新 config/label,并刷新 last_updated 时间戳。""" updates: List[str] = [] params: List[Any] = [] if config is not None: updates.append("config_json = ?") params.append(json.dumps(dict(config))) if labels is not None: updates.append("labels_json = ?") params.append(json.dumps(list(labels))) updates.append("last_updated = ?") params.append(to_iso(utcnow())) params.append(node_id) with self._lock: self._conn.execute( f"UPDATE nodes SET {', '.join(updates)} WHERE id = ?", tuple(params), ) if self._conn.total_changes == 0: self._conn.rollback() raise KeyError(node_id) self._conn.commit() updated = self.get_node(node_id) if updated is None: raise KeyError(node_id) return updated def update_last_report( self, node_id: str, *, server_timestamp_iso: str, agent_timestamp_iso: str, health: Mapping[str, Any], ) -> Dict[str, Any]: """记录最新上报时间和健康信息,用于后续状态计算。""" with self._lock: self._conn.execute( """ UPDATE nodes SET last_report = ?, agent_last_report = ?, health_json = ?, last_updated = ? WHERE id = ? """, ( server_timestamp_iso, agent_timestamp_iso, json.dumps(health), server_timestamp_iso, node_id, ), ) if self._conn.total_changes == 0: self._conn.rollback() raise KeyError(node_id) self._conn.commit() updated = self.get_node(node_id) if updated is None: raise KeyError(node_id) return updated def update_status(self, node_id: str, status: str, *, last_updated_iso: str) -> None: with self._lock: self._conn.execute( "UPDATE nodes SET status = ?, last_updated = ? WHERE id = ?", (status, last_updated_iso, node_id), ) self._conn.commit() # ------------------------------------------------------------------ # Reporting helpers # ------------------------------------------------------------------ def get_statistics(self) -> Dict[str, Any]: """统计节点总数及按状态聚合的数量。""" with self._lock: cur = self._conn.execute("SELECT COUNT(*) AS total FROM nodes") total_row = cur.fetchone() cur = self._conn.execute("SELECT status, COUNT(*) AS count FROM nodes GROUP BY status") status_rows = cur.fetchall() return { "total": total_row["total"] if total_row else 0, "status_statistics": [ {"status": row["status"], "count": row["count"]} for row in status_rows ], } def fetch_nodes_for_scheduler(self) -> List[sqlite3.Row]: with self._lock: cur = self._conn.execute( "SELECT id, last_report, status FROM nodes" ) return cur.fetchall() def get_online_nodes(self) -> List[Dict[str, Any]]: """返回在线节点列表,用于生成 nodes.json。""" with self._lock: cur = self._conn.execute( "SELECT id, meta_json, labels_json, name FROM nodes WHERE status = ? ORDER BY id ASC", ("online",), ) rows = cur.fetchall() result: List[Dict[str, Any]] = [] for row in rows: meta = json.loads(row["meta_json"]) if row["meta_json"] else {} labels = json.loads(row["labels_json"]) if row["labels_json"] else [] result.append( { "node_id": row["id"], "user_id": meta.get("user"), "ip": meta.get("ip"), "hostname": meta.get("hostname", row["name"]), "labels": labels if isinstance(labels, list) else [], } ) return result