argus-netconf-exporter/tests/test_sqlite_store.py
2025-11-28 14:35:21 +08:00

254 lines
6.7 KiB
Python

import sqlite3
import threading
import time
from contextlib import closing
from pathlib import Path

import pytest
from cryptography.fernet import Fernet

from exporter.config import DeviceConfig
from exporter.sqlite_store import PasswordEncryptor, SQLiteDeviceStore
@pytest.fixture
def encryptor() -> PasswordEncryptor:
    """Build a PasswordEncryptor around a freshly generated Fernet key.

    The fixture uses the default (function) scope, so every test gets its own
    key and no key material is shared between tests.
    """
    raw_key = Fernet.generate_key()
    return PasswordEncryptor(raw_key.decode())
def test_init_db_creates_devices_table(tmp_path: Path, encryptor: PasswordEncryptor):
    """init_db() must create the ``devices`` table in the configured database file."""
    db_path = tmp_path / "test.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    # closing() releases the inspection connection even if the query raises;
    # sqlite3.Connection's own context manager only commits/rolls back and
    # does not close the connection.
    with closing(sqlite3.connect(str(db_path))) as conn:
        row = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='devices'"
        ).fetchone()
    assert row is not None
def test_save_and_load_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """A saved runtime device must load back with all of its fields intact."""
    db_path = tmp_path / "test.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    dev = DeviceConfig(
        name="dev1",
        host="1.1.1.1",
        port=830,
        username="u",
        password="p",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev)

    loaded = store.load_runtime_devices()
    assert len(loaded) == 1
    got = loaded[0]
    assert got.name == "dev1"
    assert got.host == "1.1.1.1"
    assert got.port == 830
    assert got.username == "u"
    # The password must come back decrypted, not as the stored ciphertext.
    assert got.password == "p"
    assert got.enabled
    assert got.source == "runtime"
def test_password_not_stored_in_plaintext(tmp_path: Path, encryptor: PasswordEncryptor):
    """The password column must hold ciphertext that decrypts back to the original."""
    db_path = tmp_path / "test.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    dev = DeviceConfig(
        name="dev1",
        host="h",
        port=830,
        username="u",
        password="plain",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev)

    with closing(sqlite3.connect(str(db_path))) as conn:
        row = conn.execute(
            "SELECT password_cipher FROM devices WHERE name='dev1'"
        ).fetchone()
    assert row is not None
    cipher = row[0]

    # The column may come back as TEXT (str) or BLOB (bytes). A str is never
    # equal to b"plain", so a plain `!=` check would pass vacuously for TEXT;
    # normalise to bytes and make sure the plaintext does not appear at all.
    raw = cipher.encode() if isinstance(cipher, str) else bytes(cipher)
    assert b"plain" not in raw
    assert encryptor.decrypt(cipher) == "plain"
def test_concurrent_writes_are_serialized(tmp_path: Path, encryptor: PasswordEncryptor):
    """Concurrent save_device() calls from several threads must all succeed.

    A barrier releases every worker at the same moment, so the writes actually
    contend for the database rather than running sequentially because of
    thread start-up timing. Joins are bounded so a deadlock fails the test
    instead of hanging the run.
    """
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor, timeout=1.0)
    store.init_db()

    n_threads = 10
    start_gate = threading.Barrier(n_threads)
    errors: list[Exception] = []

    def worker(i: int) -> None:
        try:
            dev = DeviceConfig(
                name=f"dev-{i}",
                host="h",
                port=830,
                username="u",
                password="p",
                enabled=True,
                source="runtime",
            )
            # Wait until every worker is ready, then write simultaneously.
            start_gate.wait(timeout=5.0)
            store.save_device(dev)
        except Exception as e:  # noqa: BLE001 - surfaced via the errors assertion
            errors.append(e)

    threads = [
        threading.Thread(target=worker, args=(i,), daemon=True)
        for i in range(n_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=10.0)

    assert not any(t.is_alive() for t in threads), "a writer thread hung"
    assert errors == []
    loaded = store.load_runtime_devices()
    assert {d.name for d in loaded} == {f"dev-{i}" for i in range(n_threads)}
def test_each_operation_closes_connection(tmp_path: Path, encryptor: PasswordEncryptor, monkeypatch):
    """Every connection the store opens must be closed again.

    sqlite3.connect is patched to return proxies that record their own
    close() call. A leaked connection therefore shows up as a proxy that was
    never closed. Checking each proxy individually also catches one
    connection being closed twice while another leaks, which a bare count
    comparison would miss.
    """
    db_path = tmp_path / "devices.db"
    real_connect = sqlite3.connect
    connect_calls: list["ConnWrapper"] = []
    close_calls: list["ConnWrapper"] = []

    class ConnWrapper:
        """Proxy over a real sqlite3.Connection that tracks close()."""

        _OWN_ATTRS = ("_inner", "was_closed")

        def __init__(self, inner: sqlite3.Connection) -> None:
            self._inner = inner
            self.was_closed = False
            connect_calls.append(self)

        def __getattr__(self, name: str):
            # Forward everything not defined here (execute, executemany,
            # cursor, commit, rollback, ...) to the real connection.
            return getattr(self._inner, name)

        def __setattr__(self, name: str, value) -> None:
            # Attribute writes such as `conn.row_factory = ...` must reach the
            # real connection, otherwise they would silently land on the proxy.
            if name in self._OWN_ATTRS:
                object.__setattr__(self, name, value)
            else:
                setattr(self._inner, name, value)

        # Dunder lookups bypass __getattr__, so the context-manager protocol
        # is forwarded explicitly for code that uses `with conn:`.
        def __enter__(self) -> "ConnWrapper":
            self._inner.__enter__()
            return self

        def __exit__(self, exc_type, exc, tb):
            return self._inner.__exit__(exc_type, exc, tb)

        def close(self) -> None:
            self.was_closed = True
            close_calls.append(self)
            return self._inner.close()

    def tracked_connect(*args, **kwargs):
        return ConnWrapper(real_connect(*args, **kwargs))

    monkeypatch.setattr(sqlite3, "connect", tracked_connect)

    store = SQLiteDeviceStore(str(db_path), encryptor, timeout=1.0)
    store.init_db()
    for i in range(5):
        dev = DeviceConfig(
            name=f"d{i}",
            host="h",
            port=830,
            username="u",
            password="p",
            enabled=True,
            source="runtime",
        )
        store.save_device(dev)
    store.load_runtime_devices()

    # Without this guard the test passes vacuously if the patch never fires.
    assert connect_calls, "sqlite3.connect was never observed"
    assert [c for c in connect_calls if not c.was_closed] == []
    assert len(connect_calls) == len(close_calls)
def test_delete_existing_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """delete_device() removes the named device and leaves other devices alone."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    for name in ("dev1", "dev2"):
        store.save_device(
            DeviceConfig(
                name=name,
                host="h",
                port=830,
                username="u",
                password="p",
                enabled=True,
                source="runtime",
            )
        )

    # Guard against a vacuous pass: the target must be present before deletion.
    assert {d.name for d in store.load_runtime_devices()} == {"dev1", "dev2"}

    store.delete_device("dev1")

    assert {d.name for d in store.load_runtime_devices()} == {"dev2"}
def test_delete_non_existing_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """Deleting an unknown device name is a no-op: no exception, no rows touched."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    store.save_device(
        DeviceConfig(
            name="dev1",
            host="h",
            port=830,
            username="u",
            password="p",
            enabled=True,
            source="runtime",
        )
    )

    # Deleting a device that does not exist should be a no-op.
    store.delete_device("non-exist")

    # Loading afterwards must not raise, and the unrelated device must survive.
    loaded = store.load_runtime_devices()
    assert [d.name for d in loaded] == ["dev1"]
def test_created_at_preserved_and_updated_at_changes(tmp_path: Path, encryptor: PasswordEncryptor):
    """Re-saving an existing device updates it in place.

    After the second save, created_at must keep its original value,
    updated_at must advance, the new host must be stored, and there must
    still be exactly one row for the device.
    """
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    def query_one(sql: str) -> tuple:
        # A short-lived connection per read, so no handle is held open while
        # the store performs the write under test.
        with closing(sqlite3.connect(str(db_path))) as conn:
            return conn.execute(sql).fetchone()

    row_sql = "SELECT created_at, updated_at, host FROM devices WHERE name='dev1'"

    dev1 = DeviceConfig(
        name="dev1",
        host="h1",
        port=830,
        username="u",
        password="p",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev1)
    created1, updated1, _ = query_one(row_sql)

    # NOTE(review): the sleep assumes updated_at may have one-second
    # resolution (the timestamp format is not visible here), so wait long
    # enough for the second save to produce a strictly later value.
    time.sleep(1)

    dev2 = DeviceConfig(
        name="dev1",
        host="h2",
        port=830,
        username="u",
        password="p2",
        enabled=False,
        source="runtime",
    )
    store.save_device(dev2)
    created2, updated2, host = query_one(row_sql)

    assert created2 == created1
    assert updated2 > updated1
    assert host == "h2"
    # The second save must be an upsert, not an additional insert.
    (row_count,) = query_one("SELECT COUNT(*) FROM devices WHERE name='dev1'")
    assert row_count == 1
def test_supports_xpath_persisted(tmp_path: Path, encryptor: PasswordEncryptor):
    """supports_xpath must round-trip as a real bool for both True and False.

    Saving only a True device would not catch a store that always reports
    True, so a False device is persisted alongside it.
    """
    db_path = tmp_path / "db.sqlite"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()
    for name, flag in (("dev1", True), ("dev2", False)):
        store.save_device(
            DeviceConfig(
                name=name,
                host="h",
                port=830,
                username="u",
                password="p",
                enabled=True,
                supports_xpath=flag,
                source="runtime",
            )
        )

    loaded = store.load_runtime_devices()
    assert len(loaded) == 2
    by_name = {d.name: d.supports_xpath for d in loaded}
    # Identity checks ensure the value is decoded to bool, not left as 0/1.
    assert by_name["dev1"] is True
    assert by_name["dev2"] is False