254 lines
6.7 KiB
Python
254 lines
6.7 KiB
Python
import sqlite3
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from cryptography.fernet import Fernet
|
|
|
|
import pytest
|
|
|
|
from exporter.config import DeviceConfig
|
|
from exporter.sqlite_store import PasswordEncryptor, SQLiteDeviceStore
|
|
|
|
|
|
@pytest.fixture
def encryptor() -> PasswordEncryptor:
    """Provide a PasswordEncryptor built from a freshly generated Fernet key.

    The fixture uses pytest's default (function) scope, so every test that
    requests it receives an encryptor with its own random key.
    """
    fresh_key = Fernet.generate_key()
    return PasswordEncryptor(fresh_key.decode())
|
|
|
|
|
|
def test_init_db_creates_devices_table(tmp_path: Path, encryptor: PasswordEncryptor):
    """init_db() must create a table named ``devices`` in the database file."""
    db_path = tmp_path / "test.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    # Inspect the schema through an independent connection. Close it in
    # `finally` so a failing query does not leak the handle.
    conn = sqlite3.connect(str(db_path))
    try:
        row = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='devices'"
        ).fetchone()
    finally:
        conn.close()

    assert row is not None
|
|
|
|
|
|
def test_save_and_load_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """A saved runtime device round-trips through load_runtime_devices()."""
    store = SQLiteDeviceStore(str(tmp_path / "test.db"), encryptor)
    store.init_db()

    original = DeviceConfig(
        name="dev1",
        host="1.1.1.1",
        port=830,
        username="u",
        password="p",
        enabled=True,
        source="runtime",
    )
    store.save_device(original)

    devices = store.load_runtime_devices()
    assert len(devices) == 1

    restored = devices[0]
    assert restored.name == "dev1"
    # The plaintext password is expected back from the loader.
    assert restored.password == "p"
    assert restored.source == "runtime"
|
|
|
|
|
|
def test_password_not_stored_in_plaintext(tmp_path: Path, encryptor: PasswordEncryptor):
    """The stored password column must be ciphertext that decrypts to the input."""
    db_path = tmp_path / "test.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    dev = DeviceConfig(
        name="dev1",
        host="h",
        port=830,
        username="u",
        password="plain",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev)

    # Read the raw column through an independent connection. Close it even if
    # the query fails.
    conn = sqlite3.connect(str(db_path))
    try:
        row = conn.execute(
            "SELECT password_cipher FROM devices WHERE name='dev1'"
        ).fetchone()
    finally:
        conn.close()

    assert row is not None, "device row was not persisted"
    cipher = row[0]
    assert cipher is not None, "password_cipher column is NULL"

    # SQLite returns TEXT columns as str and BLOB columns as bytes. Compare in
    # bytes so the plaintext check cannot pass vacuously: in Python 3,
    # `"plain" != b"plain"` is always True.
    cipher_bytes = cipher.encode() if isinstance(cipher, str) else bytes(cipher)
    assert b"plain" not in cipher_bytes
    assert encryptor.decrypt(cipher) == "plain"
|
|
|
|
|
|
def test_concurrent_writes_are_serialized(tmp_path: Path, encryptor: PasswordEncryptor):
    """save_device() calls racing from many threads must all succeed and persist."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor, timeout=1.0)
    store.init_db()

    n_threads = 10
    # Release all workers at the same moment. Without this, threads started one
    # after another can each finish before the next begins, so the writes never
    # actually contend and the test says nothing about serialization.
    start_gate = threading.Barrier(n_threads)
    errors: list[Exception] = []

    def worker(i: int) -> None:
        try:
            dev = DeviceConfig(
                name=f"dev-{i}",
                host="h",
                port=830,
                username="u",
                password="p",
                enabled=True,
                source="runtime",
            )
            start_gate.wait(timeout=10)
            store.save_device(dev)
        except Exception as e:  # noqa: BLE001 - collected and asserted on below
            errors.append(e)

    # Daemon threads so a wedged worker cannot keep the interpreter alive.
    threads = [
        threading.Thread(target=worker, args=(i,), daemon=True)
        for i in range(n_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        # Bounded join: a deadlock in the store fails the test instead of
        # hanging the whole suite.
        t.join(timeout=30)

    assert not any(t.is_alive() for t in threads), "worker thread(s) did not finish"
    assert errors == []
    loaded = store.load_runtime_devices()
    assert {d.name for d in loaded} == {f"dev-{i}" for i in range(n_threads)}
|
|
|
|
|
|
def test_each_operation_closes_connection(tmp_path: Path, encryptor: PasswordEncryptor, monkeypatch):
    """Every connection the store opens via sqlite3.connect must be closed again."""
    db_path = tmp_path / "devices.db"

    real_connect = sqlite3.connect
    opened: list["ConnWrapper"] = []

    class ConnWrapper:
        """Proxy over a real connection that records whether close() was called."""

        def __init__(self, inner: sqlite3.Connection) -> None:
            self._inner = inner
            self.closed = False
            opened.append(self)

        def __getattr__(self, name):
            # Forward anything not defined here (cursor, executemany, ...), so
            # the test does not depend on exactly which Connection methods the
            # store happens to use.
            return getattr(self._inner, name)

        def execute(self, *args, **kwargs):
            return self._inner.execute(*args, **kwargs)

        def commit(self) -> None:
            return self._inner.commit()

        def rollback(self) -> None:
            return self._inner.rollback()

        def close(self) -> None:
            self.closed = True
            return self._inner.close()

    def tracked_connect(*args, **kwargs):
        return ConnWrapper(real_connect(*args, **kwargs))

    monkeypatch.setattr(sqlite3, "connect", tracked_connect)

    store = SQLiteDeviceStore(str(db_path), encryptor, timeout=1.0)
    store.init_db()

    for i in range(5):
        dev = DeviceConfig(
            name=f"d{i}",
            host="h",
            port=830,
            username="u",
            password="p",
            enabled=True,
            source="runtime",
        )
        store.save_device(dev)
        store.load_runtime_devices()

    # Guard against a vacuous pass. If the store bound `connect` at import time
    # (e.g. `from sqlite3 import connect`), the patch intercepts nothing, and a
    # bare count comparison would succeed as 0 == 0.
    assert opened, "no connection was opened through the patched sqlite3.connect"
    # Check each connection individually rather than comparing counts, so
    # closing one connection twice cannot hide another that was never closed.
    unclosed = [w for w in opened if not w.closed]
    assert unclosed == [], f"{len(unclosed)} of {len(opened)} connections left open"
|
|
|
|
|
|
def test_delete_existing_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """delete_device() removes a previously saved device."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    dev = DeviceConfig(
        name="dev1",
        host="h",
        port=830,
        username="u",
        password="p",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev)
    # Precondition: without it, a save_device() that stored nothing would make
    # the post-delete assertion below pass trivially.
    assert "dev1" in {d.name for d in store.load_runtime_devices()}

    store.delete_device("dev1")

    loaded = store.load_runtime_devices()
    assert all(d.name != "dev1" for d in loaded)
|
|
|
|
|
|
def test_delete_non_existing_device(tmp_path: Path, encryptor: PasswordEncryptor):
    """Deleting an unknown device is a silent no-op that leaves other rows intact."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    # Seed one device so the test can confirm the delete really changed nothing.
    store.save_device(
        DeviceConfig(
            name="dev1",
            host="h",
            port=830,
            username="u",
            password="p",
            enabled=True,
            source="runtime",
        )
    )

    # Deleting a device that does not exist should be a no-op.
    store.delete_device("non-exist")

    loaded = store.load_runtime_devices()  # must not raise
    assert [d.name for d in loaded] == ["dev1"]
|
|
|
|
|
|
def test_created_at_preserved_and_updated_at_changes(tmp_path: Path, encryptor: PasswordEncryptor):
    """Re-saving a device keeps created_at, advances updated_at and applies new fields."""
    db_path = tmp_path / "devices.db"
    store = SQLiteDeviceStore(str(db_path), encryptor)
    store.init_db()

    def read_row():
        # Use a short-lived connection for each read. Nothing is held open
        # while the store writes, and the handle is closed even if the query
        # fails.
        conn = sqlite3.connect(str(db_path))
        try:
            return conn.execute(
                "SELECT created_at, updated_at, host FROM devices WHERE name='dev1'"
            ).fetchone()
        finally:
            conn.close()

    dev1 = DeviceConfig(
        name="dev1",
        host="h1",
        port=830,
        username="u",
        password="p",
        enabled=True,
        source="runtime",
    )
    store.save_device(dev1)
    first = read_row()
    assert first is not None, "device row was not persisted"
    created1, updated1, _ = first

    # Wait a full second before re-saving. Presumably the store writes
    # timestamps at one-second resolution; confirm before shortening this.
    time.sleep(1)

    dev2 = DeviceConfig(
        name="dev1",
        host="h2",
        port=830,
        username="u",
        password="p2",
        enabled=False,
        source="runtime",
    )
    store.save_device(dev2)
    second = read_row()
    assert second is not None, "device row disappeared after re-save"
    created2, updated2, host = second

    assert created2 == created1
    assert updated2 > updated1
    assert host == "h2"
|
|
|
|
|
|
def test_supports_xpath_persisted(tmp_path: Path, encryptor: PasswordEncryptor):
    """The supports_xpath flag survives a save/load round trip."""
    store = SQLiteDeviceStore(str(tmp_path / "db.sqlite"), encryptor)
    store.init_db()

    store.save_device(
        DeviceConfig(
            name="dev1",
            host="h",
            port=830,
            username="u",
            password="p",
            enabled=True,
            supports_xpath=True,
            source="runtime",
        )
    )

    loaded = store.load_runtime_devices()
    assert len(loaded) == 1
    only_device = loaded[0]
    # Identity check: the loaded flag must be the bool True itself, not merely
    # a truthy value such as 1.
    assert only_device.supports_xpath is True
|