# argus-netconf-exporter/tests/test_connection.py
# 2025-11-28 14:35:21 +08:00
#
# 126 lines
# 3.4 KiB
# Python

from types import SimpleNamespace
import pytest
from exporter.config import DeviceConfig, GlobalConfig
from exporter.connection import ConnectionManager
class DummyManager:
    """Minimal stand-in for the object returned by ``ncclient.manager.connect``.

    Exposes a ``transport`` whose ``set_keepalive`` accepts an interval and
    ignores it, plus a ``closed`` flag that ``close_session`` flips to True so
    tests can observe whether the connection manager closed the session.
    """

    def __init__(self) -> None:
        def _ignore_keepalive(interval):
            return None

        self.transport = SimpleNamespace(set_keepalive=_ignore_keepalive)
        # Starts open; tests assert on this flag after invalidation / close_all.
        self.closed = False

    def close_session(self) -> None:
        """Record that the session has been closed."""
        self.closed = True


@pytest.fixture
def global_cfg() -> GlobalConfig:
    """Provide a default ``GlobalConfig`` with ``ssh_keepalive_seconds`` set to 30.

    The keepalive value is what ``test_enable_ssh_keepalive_called`` expects to
    see forwarded to the session transport.
    """
    config = GlobalConfig()
    config.ssh_keepalive_seconds = 30
    return config


@pytest.fixture
def device_cfg() -> DeviceConfig:
    """Provide one device config: ``dev1`` at 1.1.1.1:830 with dummy credentials."""
    credentials = {"username": "u", "password": "p"}
    return DeviceConfig(name="dev1", host="1.1.1.1", port=830, **credentials)


def test_acquire_session_uses_ncclient_connect(monkeypatch, global_cfg, device_cfg):
    """acquire_session should open the session through ncclient's ``manager.connect``.

    The patched ``connect`` records the keyword arguments it receives so the test
    can check that the device's host, port and credentials are passed through.
    """
    captured_kwargs = {}

    def fake_connect(**kwargs):
        captured_kwargs.update(kwargs)
        return DummyManager()

    monkeypatch.setattr("exporter.connection.ncclient.manager.connect", fake_connect)

    session = ConnectionManager(global_cfg).acquire_session(device_cfg)

    assert isinstance(session, DummyManager)
    for field in ("host", "port", "username", "password"):
        assert captured_kwargs[field] == getattr(device_cfg, field)


def test_enable_ssh_keepalive_called(monkeypatch, global_cfg, device_cfg):
    """acquire_session should set the transport keepalive to ``ssh_keepalive_seconds``."""

    class DummyTransport:
        """Transport double that remembers the interval passed to set_keepalive."""

        def __init__(self) -> None:
            # Stays None unless set_keepalive is actually called.
            self.keepalive_interval = None

        def set_keepalive(self, interval: int) -> None:
            self.keepalive_interval = interval

    class ManagerWithTransport(DummyManager):
        """DummyManager whose transport records set_keepalive calls."""

        def __init__(self) -> None:
            super().__init__()
            self.transport = DummyTransport()

    manager = ManagerWithTransport()
    monkeypatch.setattr(
        "exporter.connection.ncclient.manager.connect",
        lambda **kwargs: manager,
    )

    ConnectionManager(global_cfg).acquire_session(device_cfg)

    assert manager.transport.keepalive_interval == global_cfg.ssh_keepalive_seconds


def test_mark_session_invalid_triggers_reconnect(monkeypatch, global_cfg, device_cfg):
    """Invalidating a device's session should force a fresh connect and close the old one.

    Each call to the patched ``connect`` creates and records a new DummyManager,
    so the number of recorded managers equals the number of connects performed.
    """
    mgr_instances = []

    def fake_connect(**kwargs):
        mgr = DummyManager()
        mgr_instances.append(mgr)
        return mgr

    monkeypatch.setattr("exporter.connection.ncclient.manager.connect", fake_connect)
    cm = ConnectionManager(global_cfg)

    sess1 = cm.acquire_session(device_cfg)
    assert len(mgr_instances) == 1
    assert sess1 is mgr_instances[0]
    # Nothing has invalidated the first session yet, so it must still be open;
    # this makes the final "closed" assertion attributable to the invalidation.
    assert sess1.closed is False

    cm.mark_session_invalid(device_cfg.name)
    sess2 = cm.acquire_session(device_cfg)

    assert len(mgr_instances) == 2
    assert sess2 is mgr_instances[1]
    # The original session should be closed.
    assert sess1.closed is True


def test_close_all_closes_all_sessions(monkeypatch, global_cfg):
    """close_all should close every session the manager has opened."""
    mgr_instances = []

    def fake_connect(**kwargs):
        mgr = DummyManager()
        mgr_instances.append(mgr)
        return mgr

    monkeypatch.setattr("exporter.connection.ncclient.manager.connect", fake_connect)
    cm = ConnectionManager(global_cfg)
    dev1 = DeviceConfig(name="dev1", host="1.1.1.1", port=830, username="u", password="p")
    dev2 = DeviceConfig(name="dev2", host="2.2.2.2", port=830, username="u", password="p")
    cm.acquire_session(dev1)
    cm.acquire_session(dev2)
    # One connect per distinct device.
    assert len(mgr_instances) == 2
    # Precondition: both sessions are still open, so any closure observed
    # below is attributable to close_all().
    assert [mgr.closed for mgr in mgr_instances] == [False, False]

    cm.close_all()

    assert [mgr.closed for mgr in mgr_instances] == [True, True]