"""Unit tests for ``exporter.connection.ConnectionManager``.

``ncclient.manager.connect`` is monkeypatched in every test, so no real
NETCONF/SSH connection is ever attempted.
"""

from types import SimpleNamespace

import pytest

from exporter.config import DeviceConfig, GlobalConfig
from exporter.connection import ConnectionManager

# Dotted path of the connect function that ConnectionManager calls; every test
# patches this so sessions are fake objects.
CONNECT_TARGET = "exporter.connection.ncclient.manager.connect"


class DummyManager:
    """Stand-in for an ncclient manager that records whether it was closed."""

    def __init__(self) -> None:
        # ConnectionManager calls ``transport.set_keepalive`` on new sessions,
        # so the stub needs a no-op one.
        self.transport = SimpleNamespace(set_keepalive=lambda interval: None)
        self.closed = False

    def close_session(self) -> None:
        """Mark this fake session as closed."""
        self.closed = True


@pytest.fixture
def global_cfg() -> GlobalConfig:
    """Global config with a known, non-default SSH keepalive interval."""
    cfg = GlobalConfig()
    cfg.ssh_keepalive_seconds = 30
    return cfg


@pytest.fixture
def device_cfg() -> DeviceConfig:
    """A single device definition used by most tests."""
    return DeviceConfig(
        name="dev1",
        host="1.1.1.1",
        port=830,
        username="u",
        password="p",
    )


@pytest.fixture
def recorded_connect(monkeypatch) -> SimpleNamespace:
    """Patch connect to return fresh DummyManagers and record every call.

    Returns:
        A namespace with ``calls`` (the kwargs of each connect call) and
        ``managers`` (the DummyManager returned by each call), both in call
        order.
    """
    record = SimpleNamespace(calls=[], managers=[])

    def fake_connect(**kwargs):
        record.calls.append(kwargs)
        mgr = DummyManager()
        record.managers.append(mgr)
        return mgr

    monkeypatch.setattr(CONNECT_TARGET, fake_connect)
    return record


def test_acquire_session_uses_ncclient_connect(
    recorded_connect, global_cfg, device_cfg
):
    """acquire_session opens exactly one connection using the device's credentials."""
    cm = ConnectionManager(global_cfg)
    session = cm.acquire_session(device_cfg)

    assert len(recorded_connect.calls) == 1
    # The session handed back must be the very object connect produced.
    assert session is recorded_connect.managers[0]

    called = recorded_connect.calls[0]
    assert called["host"] == device_cfg.host
    assert called["port"] == device_cfg.port
    assert called["username"] == device_cfg.username
    assert called["password"] == device_cfg.password


def test_enable_ssh_keepalive_called(monkeypatch, global_cfg, device_cfg):
    """The configured keepalive interval is applied to the session transport."""

    class DummyTransport:
        def __init__(self) -> None:
            self.keepalive_interval = None

        def set_keepalive(self, interval: int) -> None:
            self.keepalive_interval = interval

    class ManagerWithTransport(DummyManager):
        def __init__(self) -> None:
            super().__init__()
            self.transport = DummyTransport()

    mgr_instance = ManagerWithTransport()

    def fake_connect(**kwargs):
        return mgr_instance

    monkeypatch.setattr(CONNECT_TARGET, fake_connect)

    cm = ConnectionManager(global_cfg)
    cm.acquire_session(device_cfg)

    assert mgr_instance.transport.keepalive_interval == global_cfg.ssh_keepalive_seconds


def test_mark_session_invalid_triggers_reconnect(
    recorded_connect, global_cfg, device_cfg
):
    """Invalidating a session closes it and forces a fresh connect next time."""
    managers = recorded_connect.managers

    cm = ConnectionManager(global_cfg)
    sess1 = cm.acquire_session(device_cfg)
    assert len(managers) == 1
    assert sess1 is managers[0]

    cm.mark_session_invalid(device_cfg.name)
    sess2 = cm.acquire_session(device_cfg)

    assert len(managers) == 2
    assert sess2 is managers[1]
    # The original session should have been closed.
    assert managers[0].closed is True


def test_close_all_closes_all_sessions(recorded_connect, global_cfg):
    """close_all closes every session the manager has opened."""
    managers = recorded_connect.managers

    cm = ConnectionManager(global_cfg)
    dev1 = DeviceConfig(name="dev1", host="1.1.1.1", port=830, username="u", password="p")
    dev2 = DeviceConfig(name="dev2", host="2.2.2.2", port=830, username="u", password="p")

    cm.acquire_session(dev1)
    cm.acquire_session(dev2)
    assert len(managers) == 2

    cm.close_all()

    assert managers[0].closed is True
    assert managers[1].closed is True