From 03c0ab0ec7bcbbb0e1d3d07dba7722dc046f8e0c Mon Sep 17 00:00:00 2001 From: "xiuting.xu" Date: Wed, 25 Mar 2026 10:08:40 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85rtr-client=EF=BC=8C=E5=AE=8C?= =?UTF-8?q?=E5=96=84tls=E8=BF=9E=E6=8E=A5=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 3 +- README.md | 68 +- data/vrps.txt | 4 + src/bin/rtr_debug_client/README.md | 272 ++++- src/bin/rtr_debug_client/main.rs | 408 ++++++- src/bin/rtr_debug_client/pretty.rs | 57 +- src/bin/rtr_debug_client/protocol.rs | 6 +- src/main.rs | 155 ++- src/rtr/cache.rs | 929 ---------------- src/rtr/cache/core.rs | 632 +++++++++++ src/rtr/cache/mod.rs | 14 + src/rtr/cache/model.rs | 392 +++++++ src/rtr/cache/ordering.rs | 311 ++++++ src/rtr/cache/store.rs | 195 ++++ src/rtr/mod.rs | 4 +- src/rtr/payload.rs | 96 +- src/rtr/pdu.rs | 422 ++++++-- src/rtr/server/config.rs | 10 +- src/rtr/server/connection.rs | 72 +- src/rtr/server/listener.rs | 100 +- src/rtr/server/service.rs | 19 +- src/rtr/server/tls.rs | 75 +- src/rtr/session.rs | 1099 ++++++++++++++++--- src/rtr/state.rs | 8 +- src/rtr/store.rs | 619 +++++++++++ src/rtr/store_db.rs | 310 ------ tests/common/test_helper.rs | 39 +- tests/fixtures/tls/client-bad.cnf | 14 + tests/fixtures/tls/client-bad.crt | 20 + tests/fixtures/tls/client-bad.key | 28 + tests/fixtures/tls/client-ca.crt | 19 + tests/fixtures/tls/client-ca.key | 28 + tests/fixtures/tls/client-good.cnf | 14 + tests/fixtures/tls/client-good.crt | 20 + tests/fixtures/tls/client-good.key | 28 + tests/fixtures/tls/server-dns.cnf | 14 + tests/fixtures/tls/server-dns.crt | 19 + tests/fixtures/tls/server-dns.key | 28 + tests/fixtures/tls/server.cnf | 14 + tests/fixtures/tls/server.crt | 20 + tests/fixtures/tls/server.key | 28 + tests/test_cache.rs | 831 +++++++++++++-- tests/test_pdu.rs | 204 ++++ tests/test_session.rs | 1482 ++++++++++++++++++++++++-- tests/test_store_db.rs | 226 +++- 45 files changed, 7521 
insertions(+), 1835 deletions(-) create mode 100644 data/vrps.txt delete mode 100644 src/rtr/cache.rs create mode 100644 src/rtr/cache/core.rs create mode 100644 src/rtr/cache/mod.rs create mode 100644 src/rtr/cache/model.rs create mode 100644 src/rtr/cache/ordering.rs create mode 100644 src/rtr/cache/store.rs create mode 100644 src/rtr/store.rs delete mode 100644 src/rtr/store_db.rs create mode 100644 tests/fixtures/tls/client-bad.cnf create mode 100644 tests/fixtures/tls/client-bad.crt create mode 100644 tests/fixtures/tls/client-bad.key create mode 100644 tests/fixtures/tls/client-ca.crt create mode 100644 tests/fixtures/tls/client-ca.key create mode 100644 tests/fixtures/tls/client-good.cnf create mode 100644 tests/fixtures/tls/client-good.crt create mode 100644 tests/fixtures/tls/client-good.key create mode 100644 tests/fixtures/tls/server-dns.cnf create mode 100644 tests/fixtures/tls/server-dns.crt create mode 100644 tests/fixtures/tls/server-dns.key create mode 100644 tests/fixtures/tls/server.cnf create mode 100644 tests/fixtures/tls/server.crt create mode 100644 tests/fixtures/tls/server.key create mode 100644 tests/test_pdu.rs diff --git a/Cargo.toml b/Cargo.toml index 16e3f63..85d3734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,4 +29,5 @@ tokio-rustls = "0.26" rustls = "0.23" rustls-pemfile = "2" rustls-pki-types = "1.14.0" -tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } \ No newline at end of file +socket2 = "0.5" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } diff --git a/README.md b/README.md index e89d904..853e4a5 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,73 @@ +# RPKI RTR Server +Default runtime target: Ubuntu/Linux. Windows is only used during development. -# 单元测试 +## Tests +```bash +cargo test ``` -cargo test -# 查看输出 +To show test output: + +```bash cargo test -- --nocapture ``` +## RTR Server + +The RTR server binary reads its runtime configuration from environment variables. 
+If an environment variable is not set, the built-in default from `src/main.rs` +is used. + +### Environment Variables + +| Variable | Description | Example | +| --- | --- | --- | +| `RPKI_RTR_ENABLE_TLS` | Enable TLS listener in addition to TCP. Accepts `true/false`, `1/0`, `yes/no`, `on/off`. | `true` | +| `RPKI_RTR_TCP_ADDR` | TCP bind address. | `0.0.0.0:3323` | +| `RPKI_RTR_TLS_ADDR` | TLS bind address. | `0.0.0.0:3324` | +| `RPKI_RTR_DB_PATH` | RTR RocksDB path. | `./rtr-db` | +| `RPKI_RTR_VRP_FILE` | Input VRP file path. | `./data/vrps.txt` | +| `RPKI_RTR_TLS_CERT_PATH` | TLS server certificate path. | `./certs/server.crt` | +| `RPKI_RTR_TLS_KEY_PATH` | TLS server private key path. | `./certs/server.key` | +| `RPKI_RTR_TLS_CLIENT_CA_PATH` | Client CA certificate path used to verify router certificates. | `./certs/client-ca.crt` | +| `RPKI_RTR_MAX_DELTA` | Maximum retained delta count. | `100` | +| `RPKI_RTR_REFRESH_INTERVAL_SECS` | VRP reload interval in seconds. | `300` | +| `RPKI_RTR_MAX_CONNECTIONS` | Maximum concurrent RTR connections. | `512` | +| `RPKI_RTR_NOTIFY_QUEUE_SIZE` | Broadcast queue size for serial notify events. | `1024` | +| `RPKI_RTR_TCP_KEEPALIVE_SECS` | TCP keepalive time in seconds. Set `0` to disable. | `60` | +| `RPKI_RTR_WARN_INSECURE_TCP` | Emit a warning when plain TCP is enabled. Accepts boolean values. | `true` | +| `RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN` | Strict mode: reject TLS server certificates that do not contain a `subjectAltName dNSName`. Accepts boolean values. | `false` | + +### Notes + +- Plain TCP should only be used on a trusted and controlled network. +- TLS mode requires client certificate authentication. +- In strict TLS server certificate mode, a server certificate without + `subjectAltName dNSName` will be rejected during startup. +- `RPKI_RTR_TCP_KEEPALIVE_SECS=0` disables TCP keepalive. Any non-zero value + enables keepalive for the lifetime of each accepted socket. 
+ +## Example Startup + +### Bash + +```sh +export RPKI_RTR_ENABLE_TLS=true +export RPKI_RTR_TCP_ADDR=0.0.0.0:3323 +export RPKI_RTR_TLS_ADDR=0.0.0.0:3324 +export RPKI_RTR_DB_PATH=./rtr-db +export RPKI_RTR_VRP_FILE=./data/vrps.txt +export RPKI_RTR_TLS_CERT_PATH=./certs/server-dns.crt +export RPKI_RTR_TLS_KEY_PATH=./certs/server-dns.key +export RPKI_RTR_TLS_CLIENT_CA_PATH=./certs/client-ca.crt +export RPKI_RTR_TCP_KEEPALIVE_SECS=60 +export RPKI_RTR_WARN_INSECURE_TCP=true +export RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN=true + +cargo run +``` + +A ready-to-edit example script is provided at +[`scripts/start-rtr-server.sh`](/C:/Users/xuxiu/git_code/rpki/scripts/start-rtr-server.sh). diff --git a/data/vrps.txt b/data/vrps.txt new file mode 100644 index 0000000..b479756 --- /dev/null +++ b/data/vrps.txt @@ -0,0 +1,4 @@ +# prefix,max_len,asn +10.0.0.0/24,25,65001 +10.0.1.0/24,24,65022 +2001:db8::/32,64,65003 \ No newline at end of file diff --git a/src/bin/rtr_debug_client/README.md b/src/bin/rtr_debug_client/README.md index 9d95b97..8ad68a4 100644 --- a/src/bin/rtr_debug_client/README.md +++ b/src/bin/rtr_debug_client/README.md @@ -1,63 +1,235 @@ # rtr_debug_client -`rtr_debug_client` 是一个用于调试和联调 RTR(RPKI-to-Router)服务端的小型客户端工具。 +`rtr_debug_client` 是一个轻量级的 RTR 调试客户端,用于手工联调和协议行为观察。 -它的目标不是做一个完整的生产级 router client,而是提供一个简单、直接、可观察的调试入口,用于: +它适合以下场景: +- 在开发阶段验证 RTR server 的行为 +- 发送 `Reset Query` 和 `Serial Query` +- 观察服务端返回的各类 PDU +- 检查会话状态、`session_id`、`serial` 的变化 +- 排查 `ErrorReport`、`CacheReset`、`SerialNotify`、`RouterKey`、`ASPA` +- 联调纯 TCP 和 TLS 两种 RTR 传输方式 -- 连接 RTR server -- 发送 `Reset Query` 或 `Serial Query` -- 接收并打印服务端返回的 PDU -- 辅助排查协议实现、会话状态、序列号增量、PDU 编码等问题 +它不是生产级 router client,而是一个便于调试和观察协议细节的小工具。 ---- - -## 适用场景 - -这个工具适合以下场景: - -- 开发 RTR server 时做本地联调 -- 验证服务端是否正确返回 `Cache Response` -- 检查 `IPv4 Prefix` / `IPv6 Prefix` / `ASPA` / `End of Data` 等 PDU -- 验证 `Serial Query` 路径是否正确 -- 观察异常响应,例如 `Cache Reset` 或 `Error Report` -- 后续扩展为支持 TLS、自动断言、会话统计等调试能力 - ---- - -## 当前能力 
+## 当前支持的能力 当前版本支持: - -- TCP 连接 RTR server +- 纯 TCP 连接 +- TLS 连接 +- TLS 服务端证书校验 +- 可选的 TLS 客户端证书认证 - 发送 `Reset Query` - 发送 `Serial Query` -- 持续读取服务端返回的 PDU -- 解析并打印以下常见 PDU: - - `Serial Notify` - - `Serial Query` - - `Reset Query` - - `Cache Response` - - `IPv4 Prefix` - - `IPv6 Prefix` - - `End of Data` - - `Cache Reset` - - `Error Report` - - `ASPA` -- 基础长度校验 -- 最大 PDU 长度限制,防止异常数据导致过大内存分配 +- 保持长连接持续接收服务端 PDU +- 格式化展示以下 PDU: + - `Serial Notify` + - `Serial Query` + - `Reset Query` + - `Cache Response` + - `IPv4 Prefix` + - `IPv6 Prefix` + - `Router Key` + - `ASPA` + - `End of Data` + - `Cache Reset` + - `Error Report` +- 结构化展示 `ErrorReport`: + - 错误码及语义名称 + - encapsulated PDU 的 header 摘要 + - encapsulated PDU 原始 hex + - arbitrary text 是否为 UTF-8 + - arbitrary text 内容 +- 根据 `EndOfData` 的 timing hint 自动轮询 +- 收到 `ErrorReport` 后默认暂停自动轮询 +- 通过 `--keep-after-error` 保持错误后的自动轮询 ---- +## 构建 -## 目录结构 +```sh +cargo build --bin rtr_debug_client +``` -建议目录如下: +## 基本用法 -```text -src/ -└── bin/ - └── rtr_debug_client/ - ├── main.rs - ├── protocol.rs - ├── io.rs - ├── pretty.rs - └── README.md \ No newline at end of file +基本形式: + +```sh +cargo run --bin rtr_debug_client -- [reset|serial ] [options] +``` + +默认值: +- `addr`: `127.0.0.1:3323` +- `version`: `1` +- `mode`: `reset` +- `timeout`: `30` +- `poll`: `60` + +## TCP 示例 + +发送 `Reset Query`: + +```sh +cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 reset +``` + +发送 `Serial Query`: + +```sh +cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 serial 42 100 +``` + +持续观察错误路径: + +```sh +cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 reset --keep-after-error +``` + +## TLS 示例 + +只做服务端证书校验: + +```sh +cargo run --bin rtr_debug_client -- \ + 127.0.0.1:3324 1 reset \ + --tls \ + --ca-cert tests/fixtures/tls/client-ca.crt \ + --server-name localhost +``` + +双向 TLS 认证: + +```sh +cargo run --bin rtr_debug_client -- \ + 127.0.0.1:3324 1 reset \ + --tls \ + --ca-cert tests/fixtures/tls/client-ca.crt \ + --server-name localhost \ + 
--client-cert tests/fixtures/tls/client-good.crt \ + --client-key tests/fixtures/tls/client-good.key +``` + +双向 TLS + 错误后继续自动轮询: + +```sh +cargo run --bin rtr_debug_client -- \ + 127.0.0.1:3324 1 reset \ + --tls \ + --ca-cert tests/fixtures/tls/client-ca.crt \ + --server-name localhost \ + --client-cert tests/fixtures/tls/client-good.crt \ + --client-key tests/fixtures/tls/client-good.key \ + --keep-after-error +``` + +说明: +- 开启 `--tls` 时必须提供 `--ca-cert` +- 如果目标地址本身不适合直接作为 TLS 名称,显式提供 `--server-name` +- 客户端认证必须同时提供 `--client-cert` 和 `--client-key` + +## 命令行参数 + +- `--tls` + 使用 TLS 而不是纯 TCP。 + +- `--ca-cert ` + 用于校验服务端证书的 CA 证书文件,PEM 格式。 + +- `--server-name ` + TLS 握手时用于校验证书的服务端名称。 + +- `--client-cert ` + 双向 TLS 时使用的客户端证书,PEM 格式。 + +- `--client-key ` + 与 `--client-cert` 配套的客户端私钥,PEM 格式。 + +- `--timeout ` + 等待下一个 PDU 的读取超时时间,单位秒。 + +- `--poll ` + 在尚未拿到 `EndOfData` timing hint 前,默认使用的自动轮询间隔。 + +- `--keep-after-error` + 收到 `ErrorReport` 后不暂停自动轮询。 + +## 运行中可用命令 + +程序启动后,可以在控制台输入以下命令: + +- `help` + 显示帮助。 + +- `state` + 打印当前客户端状态。 + +- `reset` + 发送 `Reset Query`。 + +- `serial` + 使用当前 `session_id` 和 `serial` 发送 `Serial Query`。 + +- `serial ` + 使用显式参数发送 `Serial Query`。 + +- `timeout` + 查看当前读取超时设置。 + +- `timeout ` + 修改读取超时。 + +- `poll` + 查看当前自动轮询间隔、轮询来源以及暂停状态。 + +- `poll ` + 手工覆盖当前轮询间隔。 + +- `poll pause` + 暂停自动轮询。 + +- `poll resume` + 恢复自动轮询。 + +- `keep-after-error` + 查看当前是否启用了错误后持续轮询。 + +- `quit` + 退出客户端。 + +## 自动轮询行为 + +客户端会保持连接,并周期性地向服务端发起下一次查询。 + +选择下一次轮询间隔的优先级如下: +1. `retry`,当最近一次 `ErrorReport` 是 `No Data Available` 或 `Transport Failure` +2. `refresh`,如果已经从 `EndOfData` 中拿到 +3. 
启动参数里的默认轮询间隔 + +收到 `ErrorReport` 后的默认行为: +- 默认暂停自动轮询 +- 连接保持不关,方便继续观察 +- 你可以手工输入 `reset`、`serial` 或 `poll resume` 继续 + +如果带了 `--keep-after-error`: +- 收到 `ErrorReport` 后不会暂停 +- 会继续按当前有效轮询间隔自动轮询 + +特殊情况: +- 当最近一次错误是 `No Data Available` 或 `Transport Failure` 时,恢复自动轮询后会优先参考 `retry`,而不是继续只看 `refresh` + +## ErrorReport 展示内容 + +`ErrorReport` 会展示以下内容: +- 错误码及其语义名称 +- encapsulated PDU 长度 +- encapsulated PDU 的 header 摘要 + - PDU 类型 + - version + - length + - field1(按类型解释为 `session_id` 或 `error_code`) +- encapsulated PDU 原始 hex +- arbitrary text 长度 +- arbitrary text 是否是 UTF-8 +- arbitrary text 内容 + +这样在排查协议问题时,不需要先手工拆原始 hex,就能快速知道是哪一个请求触发了错误。 diff --git a/src/bin/rtr_debug_client/main.rs b/src/bin/rtr_debug_client/main.rs index 4f2cee9..3fb6ae3 100644 --- a/src/bin/rtr_debug_client/main.rs +++ b/src/bin/rtr_debug_client/main.rs @@ -1,10 +1,14 @@ use std::env; use std::io; +use std::path::{Path, PathBuf}; +use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use rustls::{ClientConfig as RustlsClientConfig, RootCertStore}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf}; use tokio::net::TcpStream; use tokio::time::{timeout, Duration, Instant}; +use tokio_rustls::TlsConnector; mod wire; mod pretty; @@ -19,15 +23,23 @@ use crate::protocol::{PduHeader, PduType, QueryMode}; const DEFAULT_READ_TIMEOUT_SECS: u64 = 30; const DEFAULT_POLL_INTERVAL_SECS: u64 = 60; +trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {} +impl AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} + +type DynStream = Box; +type ClientWriter = WriteHalf; + #[tokio::main] async fn main() -> io::Result<()> { let config = Config::from_args()?; println!("== RTR debug client =="); println!("target : {}", config.addr); + println!("transport: {}", config.transport.describe()); println!("version : {}", 
config.version); println!("timeout : {}s", config.read_timeout_secs); println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs); + println!("keep-after-error: {}", config.keep_after_error); match &config.mode { QueryMode::Reset => { println!("mode : reset"); @@ -41,14 +53,15 @@ async fn main() -> io::Result<()> { println!(); print_help(); - let stream = TcpStream::connect(&config.addr).await?; + let stream = connect_stream(&config).await?; println!("connected to {}", config.addr); - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = tokio_io::split(stream); let mut state = ClientState::new( config.version, config.read_timeout_secs, config.default_poll_secs, + config.keep_after_error, ); match config.mode { @@ -71,7 +84,7 @@ async fn main() -> io::Result<()> { let mut stdin_lines = BufReader::new(stdin).lines(); loop { - let poll_sleep = tokio::time::sleep_until(state.next_poll_deadline); + let poll_sleep = tokio::time::sleep_until(state.poll_deadline()); tokio::pin!(poll_sleep); tokio::select! 
{ @@ -131,7 +144,7 @@ async fn main() -> io::Result<()> { } async fn handle_incoming_pdu( - writer: &mut OwnedWriteHalf, + writer: &mut ClientWriter, state: &mut ClientState, header: &PduHeader, body: &[u8], @@ -139,9 +152,10 @@ async fn handle_incoming_pdu( match header.pdu_type() { PduType::CacheResponse => { state.current_session_id = Some(header.session_id()); + state.last_error_code = None; } - PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::Aspa => { + PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa => { if state.current_session_id.is_none() { state.current_session_id = Some(header.session_id()); } @@ -161,6 +175,7 @@ async fn handle_incoming_pdu( state.refresh = eod.refresh; state.retry = eod.retry; state.expire = eod.expire; + state.last_error_code = None; println!( "updated client state: session_id={}, serial={}", @@ -198,6 +213,11 @@ async fn handle_incoming_pdu( let notify_serial = parse_serial_notify_serial(body); println!(); + println!( + "[notify] received Serial Notify: session_id={}, notify_serial={:?}", + notify_session_id, + notify_serial + ); match (state.session_id, state.serial, notify_serial) { (Some(current_session_id), Some(current_serial), Some(_new_serial)) @@ -233,6 +253,7 @@ async fn handle_incoming_pdu( println!("received Cache Reset, send Reset Query"); state.current_session_id = None; state.serial = None; + state.last_error_code = None; send_reset_query(writer, state.version).await?; state.schedule_next_poll(); println!(); @@ -240,9 +261,20 @@ async fn handle_incoming_pdu( PduType::ErrorReport => { println!(); - println!("received Error Report, keep connection open for debugging."); + println!("received Error Report, pause auto polling for debugging."); + state.last_error_code = Some(header.error_code()); if let Some(retry) = state.retry { - println!("will keep auto polling; server retry hint currently stored: {}s", retry); + println!("server retry hint currently stored: {}s", retry); + if 
state.should_prefer_retry_poll() { + println!("when resumed, auto polling will use retry instead of refresh."); + } + } + if state.keep_after_error { + println!("keep-after-error is enabled, auto polling will continue."); + state.schedule_next_poll(); + } else { + println!("use `reset`, `serial`, or `poll resume` to continue manually."); + state.pause_auto_poll(); } println!(); } @@ -256,7 +288,7 @@ async fn handle_incoming_pdu( } async fn handle_poll_tick( - writer: &mut OwnedWriteHalf, + writer: &mut ClientWriter, state: &mut ClientState, ) -> io::Result<()> { println!(); @@ -285,7 +317,7 @@ async fn handle_poll_tick( async fn handle_console_command( line: &str, - writer: &mut OwnedWriteHalf, + writer: &mut ClientWriter, state: &mut ClientState, ) -> io::Result { let line = line.trim(); @@ -382,8 +414,24 @@ async fn handle_console_command( "current effective poll interval: {}s", state.effective_poll_secs() ); + println!("poll interval source : {}", state.poll_interval_source()); println!("stored refresh hint : {:?}", state.refresh); println!("default poll interval : {}s", state.default_poll_secs); + println!("last_error_code : {:?}", state.last_error_code); + println!("auto polling paused : {}", state.poll_paused); + } + + ["poll", "pause"] => { + state.pause_auto_poll(); + println!("auto polling paused"); + } + + ["poll", "resume"] => { + state.resume_auto_poll(); + println!( + "auto polling resumed, next poll scheduled after {}s", + state.effective_poll_secs() + ); } ["poll", secs] => { @@ -428,6 +476,9 @@ fn print_help() { println!(" timeout update read timeout seconds"); println!(" poll show current poll interval"); println!(" poll override poll interval seconds"); + println!(" poll pause pause auto polling"); + println!(" poll resume resume auto polling"); + println!(" keep-after-error show current keep-after-error setting"); println!(" quit exit client"); println!(); } @@ -444,6 +495,10 @@ fn print_state(state: &ClientState) { println!(" read_timeout_secs : 
{}", state.read_timeout_secs); println!(" default_poll_secs : {}", state.default_poll_secs); println!(" effective_poll_secs: {}", state.effective_poll_secs()); + println!(" poll_source : {}", state.poll_interval_source()); + println!(" last_error_code : {:?}", state.last_error_code); + println!(" keep_after_error : {}", state.keep_after_error); + println!(" poll_paused : {}", state.poll_paused); println!(); } @@ -457,14 +512,22 @@ struct ClientState { refresh: Option, retry: Option, expire: Option, + last_error_code: Option, + keep_after_error: bool, read_timeout_secs: u64, default_poll_secs: u64, next_poll_deadline: Instant, + poll_paused: bool, } impl ClientState { - fn new(version: u8, read_timeout_secs: u64, default_poll_secs: u64) -> Self { + fn new( + version: u8, + read_timeout_secs: u64, + default_poll_secs: u64, + keep_after_error: bool, + ) -> Self { Self { version, session_id: None, @@ -473,20 +536,60 @@ impl ClientState { refresh: None, retry: None, expire: None, + last_error_code: None, + keep_after_error, read_timeout_secs, default_poll_secs, next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs), + poll_paused: false, } } fn effective_poll_secs(&self) -> u64 { - self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs) + if self.should_prefer_retry_poll() { + self.retry + .map(|v| v as u64) + .unwrap_or_else(|| self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs)) + } else { + self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs) + } } fn schedule_next_poll(&mut self) { self.next_poll_deadline = Instant::now() + Duration::from_secs(self.effective_poll_secs()); } + + fn pause_auto_poll(&mut self) { + self.poll_paused = true; + } + + fn resume_auto_poll(&mut self) { + self.poll_paused = false; + self.schedule_next_poll(); + } + + fn poll_deadline(&self) -> Instant { + if self.poll_paused { + Instant::now() + Duration::from_secs(365 * 24 * 60 * 60) + } else { + self.next_poll_deadline + } + } + + fn 
should_prefer_retry_poll(&self) -> bool { + matches!(self.last_error_code, Some(2 | 10)) + } + + fn poll_interval_source(&self) -> &'static str { + if self.should_prefer_retry_poll() && self.retry.is_some() { + "retry" + } else if self.refresh.is_some() { + "refresh" + } else { + "default" + } + } } #[derive(Debug)] @@ -496,17 +599,80 @@ struct Config { mode: QueryMode, read_timeout_secs: u64, default_poll_secs: u64, + transport: TransportConfig, + keep_after_error: bool, } impl Config { fn from_args() -> io::Result { let mut args = env::args().skip(1); + let mut positional = Vec::new(); + let mut transport = TransportConfig::Tcp; + let mut read_timeout_secs = DEFAULT_READ_TIMEOUT_SECS; + let mut default_poll_secs = DEFAULT_POLL_INTERVAL_SECS; + let mut keep_after_error = false; - let addr = args + while let Some(arg) = args.next() { + match arg.as_str() { + "--tls" => { + transport = TransportConfig::Tls(TlsConfig::default()); + } + "--ca-cert" => { + let path = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--ca-cert requires a path") + })?; + ensure_tls_config(&mut transport)?.ca_cert = Some(PathBuf::from(path)); + } + "--client-cert" => { + let path = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--client-cert requires a path") + })?; + ensure_tls_config(&mut transport)?.client_cert = Some(PathBuf::from(path)); + } + "--client-key" => { + let path = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--client-key requires a path") + })?; + ensure_tls_config(&mut transport)?.client_key = Some(PathBuf::from(path)); + } + "--server-name" => { + let name = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--server-name requires a value") + })?; + ensure_tls_config(&mut transport)?.server_name = Some(name); + } + "--timeout" => { + let secs = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--timeout requires seconds") + })?; + 
read_timeout_secs = parse_u64_arg(&secs, "--timeout")?; + } + "--poll" => { + let secs = args.next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "--poll requires seconds") + })?; + default_poll_secs = parse_u64_arg(&secs, "--poll")?; + } + "--keep-after-error" => { + keep_after_error = true; + } + _ if arg.starts_with("--") => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown option '{}'", arg), + )); + } + _ => positional.push(arg), + } + } + + let mut positional = positional.into_iter(); + + let addr = positional .next() .unwrap_or_else(|| "127.0.0.1:3323".to_string()); - let version = args + let version = positional .next() .map(|s| { s.parse::().map_err(|e| { @@ -526,10 +692,10 @@ impl Config { )); } - let mode = match args.next().as_deref() { + let mode = match positional.next().as_deref() { None | Some("reset") => QueryMode::Reset, Some("serial") => { - let session_id = args + let session_id = positional .next() .ok_or_else(|| { io::Error::new( @@ -545,7 +711,7 @@ impl Config { ) })?; - let serial = args + let serial = positional .next() .ok_or_else(|| { io::Error::new( @@ -571,12 +737,212 @@ impl Config { } }; + let transport = finalize_transport(transport, &addr)?; + Ok(Self { addr, version, mode, - read_timeout_secs: DEFAULT_READ_TIMEOUT_SECS, - default_poll_secs: DEFAULT_POLL_INTERVAL_SECS, + read_timeout_secs, + default_poll_secs, + transport, + keep_after_error, }) } -} \ No newline at end of file +} + +#[derive(Debug, Clone)] +enum TransportConfig { + Tcp, + Tls(TlsConfig), +} + +impl TransportConfig { + fn describe(&self) -> String { + match self { + Self::Tcp => "tcp".to_string(), + Self::Tls(cfg) => format!( + "tls (server_name={}, ca_cert={}, client_cert={})", + cfg.server_name.as_deref().unwrap_or(""), + cfg.ca_cert + .as_ref() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| "".to_string()), + cfg.client_cert + .as_ref() + .map(|path| path.display().to_string()) + .unwrap_or_else(|| 
"".to_string()) + ), + } + } +} + +#[derive(Debug, Clone, Default)] +struct TlsConfig { + server_name: Option, + ca_cert: Option, + client_cert: Option, + client_key: Option, +} + +fn ensure_tls_config(transport: &mut TransportConfig) -> io::Result<&mut TlsConfig> { + if matches!(transport, TransportConfig::Tcp) { + *transport = TransportConfig::Tls(TlsConfig::default()); + } + + match transport { + TransportConfig::Tls(cfg) => Ok(cfg), + TransportConfig::Tcp => unreachable!(), + } +} + +fn finalize_transport(transport: TransportConfig, addr: &str) -> io::Result { + match transport { + TransportConfig::Tcp => Ok(TransportConfig::Tcp), + TransportConfig::Tls(mut cfg) => { + let ca_cert = cfg.ca_cert.take().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "TLS mode requires --ca-cert ", + ) + })?; + + match (&cfg.client_cert, &cfg.client_key) { + (Some(_), Some(_)) | (None, None) => {} + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "TLS client authentication requires both --client-cert and --client-key", + )); + } + } + + let server_name = cfg + .server_name + .take() + .or_else(|| default_server_name_for_addr(addr)) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "TLS mode requires --server-name or an address with a parsable host", + ) + })?; + + Ok(TransportConfig::Tls(TlsConfig { + server_name: Some(server_name), + ca_cert: Some(ca_cert), + client_cert: cfg.client_cert, + client_key: cfg.client_key, + })) + } + } +} + +async fn connect_stream(config: &Config) -> io::Result { + match &config.transport { + TransportConfig::Tcp => Ok(Box::new(TcpStream::connect(&config.addr).await?)), + TransportConfig::Tls(tls) => connect_tls_stream(&config.addr, tls).await, + } +} + +async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result { + let stream = TcpStream::connect(addr).await?; + let connector = build_tls_connector(tls)?; + let server_name_str = tls + .server_name + .as_ref() + .ok_or_else(|| 
io::Error::new(io::ErrorKind::InvalidInput, "missing TLS server name"))?; + let server_name = ServerName::try_from(server_name_str.clone()).map_err(|err| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid TLS server name '{}': {}", server_name_str, err), + ) + })?; + let tls_stream = connector.connect(server_name, stream).await.map_err(|err| { + io::Error::new(io::ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", err)) + })?; + Ok(Box::new(tls_stream)) +} + +fn build_tls_connector(tls: &TlsConfig) -> io::Result { + let ca_cert_path = tls + .ca_cert + .as_ref() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS CA cert"))?; + let ca_certs = load_certs(ca_cert_path)?; + let mut roots = RootCertStore::empty(); + let (added, _ignored) = roots.add_parsable_certificates(ca_certs); + if added == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("no valid CA certificates found in {}", ca_cert_path.display()), + )); + } + + let builder = RustlsClientConfig::builder().with_root_certificates(roots); + let client_config = match (&tls.client_cert, &tls.client_key) { + (Some(cert_path), Some(key_path)) => { + let certs = load_certs(cert_path)?; + let key = load_private_key(key_path)?; + builder.with_client_auth_cert(certs, key).map_err(|err| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid TLS client certificate/key: {}", err), + ) + })? 
+ } + (None, None) => builder.with_no_client_auth(), + _ => unreachable!(), + }; + + Ok(TlsConnector::from(Arc::new(client_config))) +} + +fn load_certs(path: &Path) -> io::Result>> { + let mut reader = std::io::BufReader::new(std::fs::File::open(path)?); + let certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + if certs.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("no certificates found in {}", path.display()), + )); + } + Ok(certs) +} + +fn load_private_key(path: &Path) -> io::Result> { + let mut reader = std::io::BufReader::new(std::fs::File::open(path)?); + rustls_pemfile::private_key(&mut reader) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("no private key found in {}", path.display()), + ) + }) +} + +fn default_server_name_for_addr(addr: &str) -> Option { + if let Some(rest) = addr.strip_prefix('[') { + return rest.split(']').next().map(str::to_string); + } + addr.rsplit_once(':').map(|(host, _port)| host.to_string()) +} + +fn parse_u64_arg(value: &str, name: &str) -> io::Result { + let parsed = value.parse::().map_err(|err| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid value for {}: {}", name, err), + ) + })?; + if parsed == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("{} must be > 0", name), + )); + } + Ok(parsed) +} diff --git a/src/bin/rtr_debug_client/pretty.rs b/src/bin/rtr_debug_client/pretty.rs index 3de9202..b44ae49 100644 --- a/src/bin/rtr_debug_client/pretty.rs +++ b/src/bin/rtr_debug_client/pretty.rs @@ -25,6 +25,9 @@ pub fn print_pdu(header: &PduHeader, body: &[u8]) { PduType::Ipv6Prefix => { print_ipv6_prefix(header, body); } + PduType::RouterKey => { + print_router_key(header, body); + } PduType::EndOfData => { print_end_of_data(header, body); } @@ -128,7 +131,11 @@ fn 
print_end_of_data(header: &PduHeader, body: &[u8]) { } fn print_error_report(header: &PduHeader, body: &[u8]) { - println!("error_code : {}", header.error_code()); + println!( + "error_code : {} ({})", + header.error_code(), + error_code_name(header.error_code()) + ); if body.len() < 8 { println!("invalid ErrorReport body length: {}", body.len()); @@ -165,8 +172,27 @@ fn print_error_report(header: &PduHeader, body: &[u8]) { let text = String::from_utf8_lossy(text_bytes); println!("encap_len : {}", encapsulated_len); + if let Some(encap_header) = parse_encapsulated_header(encapsulated) { + println!("encap_pdu_type : {}", encap_header.pdu_type()); + println!("encap_version : {}", encap_header.version); + println!("encap_length : {}", encap_header.length); + match encap_header.pdu_type() { + PduType::ErrorReport => { + println!("encap_field1 : error_code={}", encap_header.error_code()); + } + PduType::Unknown(_) => { + println!("encap_field1 : {}", encap_header.field1); + } + _ => { + println!("encap_field1 : session_id={}", encap_header.session_id()); + } + } + } else if encapsulated_len > 0 { + println!("encap_header : "); + } println!("encap_pdu : {}", hex_bytes(encapsulated)); println!("text_len : {}", text_len); + println!("text_utf8 : {}", std::str::from_utf8(text_bytes).is_ok()); println!("text : {}", text); } @@ -194,7 +220,6 @@ fn print_serial_query(header: &PduHeader, body: &[u8]) { println!("serial : {}", serial); } -#[allow(dead_code)] fn print_router_key(header: &PduHeader, body: &[u8]) { println!("session_id : {}", header.session_id()); @@ -218,6 +243,34 @@ fn print_router_key(header: &PduHeader, body: &[u8]) { println!("spki : {}", hex_bytes(spki)); } +fn error_code_name(code: u16) -> &'static str { + match code { + 0 => "Corrupt Data", + 1 => "Internal Error", + 2 => "No Data Available", + 3 => "Invalid Request", + 4 => "Unsupported Protocol Version", + 5 => "Unsupported PDU Type", + 6 => "Withdrawal of Unknown Record", + 7 => "Duplicate Announcement 
Received", + 8 => "Unexpected Protocol Version", + 9 => "ASPA Provider List Error", + 10 => "Transport Failure", + 11 => "Ordering Error", + _ => "Unknown Error Code", + } +} + +fn parse_encapsulated_header(encapsulated: &[u8]) -> Option { + if encapsulated.len() < 8 { + return None; + } + + let mut header = [0u8; 8]; + header.copy_from_slice(&encapsulated[..8]); + Some(PduHeader::from_bytes(header)) +} + fn print_aspa(header: &PduHeader, body: &[u8]) { println!("session_id : {}", header.session_id()); diff --git a/src/bin/rtr_debug_client/protocol.rs b/src/bin/rtr_debug_client/protocol.rs index d2e6e37..eecd881 100644 --- a/src/bin/rtr_debug_client/protocol.rs +++ b/src/bin/rtr_debug_client/protocol.rs @@ -25,6 +25,7 @@ pub enum PduType { CacheResponse, Ipv4Prefix, Ipv6Prefix, + RouterKey, EndOfData, CacheReset, ErrorReport, @@ -41,6 +42,7 @@ impl PduType { Self::CacheResponse => 3, Self::Ipv4Prefix => 4, Self::Ipv6Prefix => 6, + Self::RouterKey => 9, Self::EndOfData => 7, Self::CacheReset => 8, Self::ErrorReport => 10, @@ -57,6 +59,7 @@ impl PduType { Self::CacheResponse => "Cache Response", Self::Ipv4Prefix => "IPv4 Prefix", Self::Ipv6Prefix => "IPv6 Prefix", + Self::RouterKey => "Router Key", Self::EndOfData => "End of Data", Self::CacheReset => "Cache Reset", Self::ErrorReport => "Error Report", @@ -75,6 +78,7 @@ impl From for PduType { 3 => Self::CacheResponse, 4 => Self::Ipv4Prefix, 6 => Self::Ipv6Prefix, + 9 => Self::RouterKey, 7 => Self::EndOfData, 8 => Self::CacheReset, 10 => Self::ErrorReport, @@ -152,4 +156,4 @@ pub fn hex_bytes(data: &[u8]) -> String { let _ = write!(out, "{:02x}", b); } out -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 1eebd79..1b9cf0d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use std::env; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; @@ -10,7 +11,7 @@ use rpki::rtr::cache::{RtrCache, SharedRtrCache}; use rpki::rtr::loader::load_vrps_from_file; use 
rpki::rtr::payload::Timing; use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService}; -use rpki::rtr::store_db::RtrStore; +use rpki::rtr::store::RtrStore; #[derive(Debug, Clone)] struct AppConfig { @@ -22,6 +23,7 @@ struct AppConfig { vrp_file: String, tls_cert_path: String, tls_key_path: String, + tls_client_ca_path: String, max_delta: u8, refresh_interval: Duration, @@ -40,23 +42,107 @@ impl Default for AppConfig { vrp_file: r"C:\Users\xuxiu\git_code\rpki\data\vrps.txt".to_string(), tls_cert_path: "./certs/server.crt".to_string(), tls_key_path: "./certs/server.key".to_string(), + tls_client_ca_path: "./certs/client-ca.crt".to_string(), max_delta: 100, - refresh_interval: Duration::from_secs(300), + refresh_interval: Duration::from_secs(10), service_config: RtrServiceConfig { max_connections: 512, notify_queue_size: 1024, + tcp_keepalive: Some(Duration::from_secs(60)), + warn_insecure_tcp: true, + require_tls_server_dns_name_san: false, }, } } } +impl AppConfig { + fn from_env() -> Result { + let mut config = Self::default(); + + if let Some(value) = env_var("RPKI_RTR_ENABLE_TLS")? { + config.enable_tls = parse_bool(&value, "RPKI_RTR_ENABLE_TLS")?; + } + if let Some(value) = env_var("RPKI_RTR_TCP_ADDR")? { + config.tcp_addr = value + .parse() + .map_err(|err| anyhow!("invalid RPKI_RTR_TCP_ADDR '{}': {}", value, err))?; + } + if let Some(value) = env_var("RPKI_RTR_TLS_ADDR")? { + config.tls_addr = value + .parse() + .map_err(|err| anyhow!("invalid RPKI_RTR_TLS_ADDR '{}': {}", value, err))?; + } + if let Some(value) = env_var("RPKI_RTR_DB_PATH")? { + config.db_path = value; + } + if let Some(value) = env_var("RPKI_RTR_VRP_FILE")? { + config.vrp_file = value; + } + if let Some(value) = env_var("RPKI_RTR_TLS_CERT_PATH")? { + config.tls_cert_path = value; + } + if let Some(value) = env_var("RPKI_RTR_TLS_KEY_PATH")? { + config.tls_key_path = value; + } + if let Some(value) = env_var("RPKI_RTR_TLS_CLIENT_CA_PATH")? 
{ + config.tls_client_ca_path = value; + } + if let Some(value) = env_var("RPKI_RTR_MAX_DELTA")? { + config.max_delta = value + .parse() + .map_err(|err| anyhow!("invalid RPKI_RTR_MAX_DELTA '{}': {}", value, err))?; + } + if let Some(value) = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")? { + let secs: u64 = value.parse().map_err(|err| { + anyhow!( + "invalid RPKI_RTR_REFRESH_INTERVAL_SECS '{}': {}", + value, + err + ) + })?; + config.refresh_interval = Duration::from_secs(secs); + } + if let Some(value) = env_var("RPKI_RTR_MAX_CONNECTIONS")? { + config.service_config.max_connections = value.parse().map_err(|err| { + anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err) + })?; + } + if let Some(value) = env_var("RPKI_RTR_NOTIFY_QUEUE_SIZE")? { + config.service_config.notify_queue_size = value.parse().map_err(|err| { + anyhow!("invalid RPKI_RTR_NOTIFY_QUEUE_SIZE '{}': {}", value, err) + })?; + } + if let Some(value) = env_var("RPKI_RTR_TCP_KEEPALIVE_SECS")? { + let secs: u64 = value.parse().map_err(|err| { + anyhow!("invalid RPKI_RTR_TCP_KEEPALIVE_SECS '{}': {}", value, err) + })?; + config.service_config.tcp_keepalive = if secs == 0 { + None + } else { + Some(Duration::from_secs(secs)) + }; + } + if let Some(value) = env_var("RPKI_RTR_WARN_INSECURE_TCP")? { + config.service_config.warn_insecure_tcp = + parse_bool(&value, "RPKI_RTR_WARN_INSECURE_TCP")?; + } + if let Some(value) = env_var("RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN")? 
{ + config.service_config.require_tls_server_dns_name_san = + parse_bool(&value, "RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN")?; + } + + Ok(config) + } +} + #[tokio::main] async fn main() -> Result<()> { init_tracing(); - let config = AppConfig::default(); + let config = AppConfig::from_env()?; log_startup_config(&config); let store = open_store(&config)?; @@ -101,8 +187,8 @@ fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result RunningRtrService config.tls_addr, &config.tls_cert_path, &config.tls_key_path, + &config.tls_client_ca_path, ) } else { info!("starting TCP RTR server"); @@ -142,6 +229,7 @@ fn spawn_refresh_task( match load_vrps_from_file(&vrp_file) { Ok(payloads) => { + let payload_count = payloads.len(); let updated = { let mut cache = match shared_cache.write() { Ok(guard) => guard, @@ -154,7 +242,27 @@ fn spawn_refresh_task( let old_serial = cache.serial(); match cache.update(payloads, &store) { - Ok(()) => cache.serial() != old_serial, + Ok(()) => { + let new_serial = cache.serial(); + if new_serial != old_serial { + info!( + "RTR cache refresh applied: vrp_file={}, payload_count={}, old_serial={}, new_serial={}", + vrp_file, + payload_count, + old_serial, + new_serial + ); + true + } else { + info!( + "RTR cache refresh found no change: vrp_file={}, payload_count={}, serial={}", + vrp_file, + payload_count, + old_serial + ); + false + } + } Err(err) => { warn!("RTR cache update failed: {:?}", err); false @@ -191,6 +299,7 @@ fn log_startup_config(config: &AppConfig) { info!("tls_addr={}", config.tls_addr); info!("tls_cert_path={}", config.tls_cert_path); info!("tls_key_path={}", config.tls_key_path); + info!("tls_client_ca_path={}", config.tls_client_ca_path); } info!("vrp_file={}", config.vrp_file); @@ -207,6 +316,22 @@ fn log_startup_config(config: &AppConfig) { "notify_queue_size={}", config.service_config.notify_queue_size ); + info!( + "tcp_keepalive_secs={}", + config + .service_config + .tcp_keepalive + .map(|duration| 
duration.as_secs().to_string()) + .unwrap_or_else(|| "disabled".to_string()) + ); + info!( + "warn_insecure_tcp={}", + config.service_config.warn_insecure_tcp + ); + info!( + "require_tls_server_dns_name_san={}", + config.service_config.require_tls_server_dns_name_san + ); } fn init_tracing() { @@ -215,4 +340,20 @@ fn init_tracing() { .with_thread_ids(true) .with_level(true) .try_init(); -} \ No newline at end of file +} + +fn env_var(name: &str) -> Result> { + match env::var(name) { + Ok(value) => Ok(Some(value)), + Err(env::VarError::NotPresent) => Ok(None), + Err(err) => Err(anyhow!("failed to read {}: {}", name, err)), + } +} + +fn parse_bool(value: &str, name: &str) -> Result { + match value.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Ok(true), + "0" | "false" | "no" | "off" => Ok(false), + _ => Err(anyhow!("invalid {} '{}': expected boolean", name, value)), + } +} diff --git a/src/rtr/cache.rs b/src/rtr/cache.rs deleted file mode 100644 index 7d2b115..0000000 --- a/src/rtr/cache.rs +++ /dev/null @@ -1,929 +0,0 @@ -use std::cmp::Ordering; -use std::collections::{BTreeSet, VecDeque}; -use std::sync::{Arc, RwLock}; -use std::time::{Duration, Instant}; - -use chrono::{DateTime, NaiveDateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; -use crate::data_model::resources::ip_resources::IPAddress; -use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Timing}; -use crate::rtr::store_db::RtrStore; - -const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(600); -const DEFAULT_EXPIRE_INTERVAL: Duration = Duration::from_secs(7200); - - - -pub type SharedRtrCache = Arc>; - -#[derive(Debug, Clone)] -pub struct DualTime { - instant: Instant, - utc: DateTime, -} - -impl DualTime { - /// Create current time. - pub fn now() -> Self { - Self { - instant: Instant::now(), - utc: Utc::now(), - } - } - - /// Get UTC time for logs. 
- pub fn utc(&self) -> DateTime { - self.utc - } - - /// Elapsed duration since creation/reset. - pub fn elapsed(&self) -> Duration { - self.instant.elapsed() - } - - /// Whether duration is expired. - pub fn is_expired(&self, duration: Duration) -> bool { - self.elapsed() >= duration - } - - /// Reset to now. - pub fn reset(&mut self) { - self.instant = Instant::now(); - self.utc = Utc::now(); - } -} - -impl Serialize for DualTime { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.utc.timestamp_millis().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for DualTime { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let millis = i64::deserialize(deserializer)?; - let naive = NaiveDateTime::from_timestamp_millis(millis) - .ok_or_else(|| serde::de::Error::custom("invalid timestamp"))?; - let utc = DateTime::::from_utc(naive, Utc); - Ok(Self { - instant: Instant::now(), - utc, - }) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Snapshot { - origins: BTreeSet, - router_keys: BTreeSet, - aspas: BTreeSet, - created_at: DualTime, - - origins_hash: [u8; 32], - router_keys_hash: [u8; 32], - aspas_hash: [u8; 32], - snapshot_hash: [u8; 32], -} - -impl Snapshot { - pub fn new( - origins: BTreeSet, - router_keys: BTreeSet, - aspas: BTreeSet, - ) -> Self { - let mut snapshot = Snapshot { - origins, - router_keys, - aspas, - created_at: DualTime::now(), - origins_hash: [0u8; 32], - router_keys_hash: [0u8; 32], - aspas_hash: [0u8; 32], - snapshot_hash: [0u8; 32], - }; - snapshot.recompute_hashes(); - snapshot - } - - pub fn empty() -> Self { - Self::new(BTreeSet::new(), BTreeSet::new(), BTreeSet::new()) - } - - pub fn from_payloads(payloads: Vec) -> Self { - let mut origins = BTreeSet::new(); - let mut router_keys = BTreeSet::new(); - let mut aspas = BTreeSet::new(); - - for p in payloads { - match p { - Payload::RouteOrigin(o) => { - origins.insert(o); - } - 
Payload::RouterKey(k) => { - router_keys.insert(k); - } - Payload::Aspa(a) => { - aspas.insert(a); - } - } - } - - Snapshot::new(origins, router_keys, aspas) - } - - pub fn recompute_hashes(&mut self) { - self.origins_hash = self.compute_origins_hash(); - self.router_keys_hash = self.compute_router_keys_hash(); - self.aspas_hash = self.compute_aspas_hash(); - self.snapshot_hash = self.compute_snapshot_hash(); - } - - fn compute_origins_hash(&self) -> [u8; 32] { - Self::hash_ordered_iter(self.origins.iter()) - } - - fn compute_router_keys_hash(&self) -> [u8; 32] { - Self::hash_ordered_iter(self.router_keys.iter()) - } - - fn compute_aspas_hash(&self) -> [u8; 32] { - Self::hash_ordered_iter(self.aspas.iter()) - } - - fn compute_snapshot_hash(&self) -> [u8; 32] { - let mut hasher = Sha256::new(); - hasher.update(b"snapshot:v1"); - hasher.update(self.origins_hash); - hasher.update(self.router_keys_hash); - hasher.update(self.aspas_hash); - hasher.finalize().into() - } - - fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32] - where - T: Serialize + 'a, - I: IntoIterator, - { - let mut hasher = Sha256::new(); - hasher.update(b"set:v1"); - - for item in iter { - let encoded = - serde_json::to_vec(item).expect("serialize snapshot item for hashing failed"); - let len = (encoded.len() as u32).to_be_bytes(); - hasher.update(len); - hasher.update(encoded); - } - - hasher.finalize().into() - } - - pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec, Vec) { - let mut announced = Vec::new(); - let mut withdrawn = Vec::new(); - - if !self.same_origins(new_snapshot) { - for origin in new_snapshot.origins.difference(&self.origins) { - announced.push(Payload::RouteOrigin(origin.clone())); - } - for origin in self.origins.difference(&new_snapshot.origins) { - withdrawn.push(Payload::RouteOrigin(origin.clone())); - } - } - - if !self.same_router_keys(new_snapshot) { - for key in new_snapshot.router_keys.difference(&self.router_keys) { - announced.push(Payload::RouterKey(key.clone())); 
- } - for key in self.router_keys.difference(&new_snapshot.router_keys) { - withdrawn.push(Payload::RouterKey(key.clone())); - } - } - - if !self.same_aspas(new_snapshot) { - for aspa in new_snapshot.aspas.difference(&self.aspas) { - announced.push(Payload::Aspa(aspa.clone())); - } - for aspa in self.aspas.difference(&new_snapshot.aspas) { - withdrawn.push(Payload::Aspa(aspa.clone())); - } - } - - (announced, withdrawn) - } - - pub fn created_at(&self) -> DualTime { - self.created_at.clone() - } - - pub fn payloads(&self) -> Vec { - let mut v = Vec::with_capacity( - self.origins.len() + self.router_keys.len() + self.aspas.len(), - ); - - v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin)); - v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey)); - v.extend(self.aspas.iter().cloned().map(Payload::Aspa)); - - v - } - - /// Payloads sorted for RTR full snapshot sending. - /// Snapshot represents current valid state, so all payloads are treated as announcements. - pub fn payloads_for_rtr(&self) -> Vec { - let mut payloads = self.payloads(); - sort_payloads_for_rtr(&mut payloads, true); - payloads - } - - pub fn origins_hash(&self) -> [u8; 32] { - self.origins_hash - } - - pub fn router_keys_hash(&self) -> [u8; 32] { - self.router_keys_hash - } - - pub fn aspas_hash(&self) -> [u8; 32] { - self.aspas_hash - } - - pub fn snapshot_hash(&self) -> [u8; 32] { - self.snapshot_hash - } - - pub fn same_origins(&self, other: &Self) -> bool { - self.origins_hash == other.origins_hash - } - - pub fn same_router_keys(&self, other: &Self) -> bool { - self.router_keys_hash == other.router_keys_hash - } - - pub fn same_aspas(&self, other: &Self) -> bool { - self.aspas_hash == other.aspas_hash - } - - pub fn same_content(&self, other: &Self) -> bool { - self.snapshot_hash == other.snapshot_hash - } - - pub fn origins(&self) -> &BTreeSet { - &self.origins - } - - pub fn router_keys(&self) -> &BTreeSet { - &self.router_keys - } - - pub fn aspas(&self) -> 
&BTreeSet { - &self.aspas - } -} - -impl Snapshot { - pub fn is_empty(&self) -> bool { - self.origins.is_empty() - && self.router_keys.is_empty() - && self.aspas.is_empty() - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Delta { - serial: u32, - announced: Vec, - withdrawn: Vec, - created_at: DualTime, -} - -impl Delta { - pub fn new(serial: u32, mut announced: Vec, mut withdrawn: Vec) -> Self { - sort_payloads_for_rtr(&mut announced, true); - sort_payloads_for_rtr(&mut withdrawn, false); - - Delta { - serial, - announced, - withdrawn, - created_at: DualTime::now(), - } - } - - pub fn serial(&self) -> u32 { - self.serial - } - - pub fn announced(&self) -> &[Payload] { - &self.announced - } - - pub fn withdrawn(&self) -> &[Payload] { - &self.withdrawn - } - - pub fn created_at(&self) -> DualTime { - self.created_at.clone() - } -} - -#[derive(Debug)] -pub struct RtrCache { - // Session ID created at cache startup. - session_id: u16, - // Current serial. - pub serial: u32, - // Full snapshot. - pub snapshot: Snapshot, - // Delta window. - deltas: VecDeque>, - // Max number of deltas to keep. - max_delta: u8, - // Refresh interval. - timing: Timing, - // Last update begin time. - last_update_begin: DualTime, - // Last update end time. - last_update_end: DualTime, - // Cache created time. 
- created_at: DualTime, -} - -impl Default for RtrCache { - fn default() -> Self { - let now = DualTime::now(); - Self { - session_id: rand::random(), - serial: 0, - snapshot: Snapshot::empty(), - deltas: VecDeque::with_capacity(100), - max_delta: 100, - timing: Timing::default(), - last_update_begin: now.clone(), - last_update_end: now.clone(), - created_at: now, - } - } -} - -pub struct RtrCacheBuilder { - session_id: Option, - max_delta: Option, - timing: Option, - serial: Option, - snapshot: Option, - deltas: Option>>, - created_at: Option, -} - -impl RtrCacheBuilder { - pub fn new() -> Self { - Self { - session_id: None, - max_delta: None, - timing: None, - serial: None, - snapshot: None, - deltas: None, - created_at: None, - } - } - - pub fn session_id(mut self, v: u16) -> Self { - self.session_id = Some(v); - self - } - - pub fn max_delta(mut self, v: u8) -> Self { - self.max_delta = Some(v); - self - } - - pub fn timing(mut self, v: Timing) -> Self { - self.timing = Some(v); - self - } - - pub fn serial(mut self, v: u32) -> Self { - self.serial = Some(v); - self - } - - pub fn snapshot(mut self, v: Snapshot) -> Self { - self.snapshot = Some(v); - self - } - - pub fn deltas(mut self, v: VecDeque>) -> Self { - self.deltas = Some(v); - self - } - - pub fn created_at(mut self, v: DualTime) -> Self { - self.created_at = Some(v); - self - } - - pub fn build(self) -> RtrCache { - let now = DualTime::now(); - let max_delta = self.max_delta.unwrap_or(100); - let timing = self.timing.unwrap_or_default(); - let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty); - let deltas = self - .deltas - .unwrap_or_else(|| VecDeque::with_capacity(max_delta.into())); - - let serial = self.serial.unwrap_or(0); - let created_at = self.created_at.unwrap_or_else(|| now.clone()); - let session_id = self.session_id.unwrap_or_else(rand::random); - - RtrCache { - session_id, - serial, - snapshot, - deltas, - max_delta, - timing, - last_update_begin: now.clone(), - last_update_end: 
now, - created_at, - } - } -} - -impl RtrCache { - /// Initialize cache from DB if possible; otherwise from file loader. - pub fn init( - self, - store: &RtrStore, - max_delta: u8, - timing: Timing, - file_loader: impl Fn() -> anyhow::Result>, - ) -> anyhow::Result { - if let Some(cache) = Self::try_restore_from_store(store, max_delta, timing)? { - tracing::info!( - "RTR cache restored from store: session_id={}, serial={}", - self.session_id, - self.serial - ); - return Ok(cache); - } - - tracing::warn!("RTR cache store unavailable or invalid, fallback to file loader"); - - let payloads = file_loader()?; - let snapshot = Snapshot::from_payloads(payloads); - - if snapshot.is_empty() { - anyhow::bail!("file loader returned an empty snapshot"); - } - - tracing::info!( - "RTR cache initialized from file loader: session_id={}, serial={}", - self.session_id, - self.serial - ); - - let serial = 1; - let session_id: u16 = rand::random(); - - let snapshot_for_store = snapshot.clone(); - let snapshot_for_cache = snapshot.clone(); - let store = store.clone(); - - tokio::spawn(async move { - if let Err(e) = - store.save_snapshot_and_meta(&snapshot_for_store, session_id, serial) - { - tracing::error!("persist failed: {:?}", e); - } - }); - - Ok(RtrCacheBuilder::new() - .session_id(session_id) - .max_delta(max_delta) - .timing(timing) - .serial(serial) - .snapshot(snapshot_for_cache) - .build()) - } - - fn try_restore_from_store( - store: &RtrStore, - max_delta: u8, - timing: Timing, - ) -> anyhow::Result> { - let snapshot = store.get_snapshot()?; - let session_id = store.get_session_id()?; - let serial = store.get_serial()?; - - let (snapshot, session_id, serial) = match (snapshot, session_id, serial) { - (Some(snapshot), Some(session_id), Some(serial)) => (snapshot, session_id, serial), - _ => { - tracing::warn!("RTR cache store incomplete: snapshot/session_id/serial missing"); - return Ok(None); - } - }; - - if snapshot.is_empty() { - tracing::warn!("RTR cache store snapshot 
is empty, treat as unusable"); - return Ok(None); - } - - let mut cache = RtrCacheBuilder::new() - .session_id(session_id) - .max_delta(max_delta) - .timing(timing) - .serial(serial) - .snapshot(snapshot) - .build(); - - match store.get_delta_window()? { - Some((min_serial, _max_serial)) => { - let deltas = match store.load_deltas_since(min_serial.wrapping_sub(1)) { - Ok(deltas) => deltas, - Err(err) => { - tracing::warn!( - "RTR cache store delta recovery failed, treat store as unusable: {:?}", - err - ); - return Ok(None); - } - }; - - for delta in deltas { - cache.push_delta(Arc::new(delta)); - } - } - None => { - tracing::info!("RTR cache store has no delta window, restore snapshot only"); - } - } - - Ok(Some(cache)) - } - - fn next_serial(&mut self) -> u32 { - self.serial = self.serial.wrapping_add(1); - self.serial - } - - fn push_delta(&mut self, delta: Arc) { - if self.deltas.len() >= self.max_delta as usize { - self.deltas.pop_front(); - } - self.deltas.push_back(delta); - } - - fn replace_snapshot(&mut self, snapshot: Snapshot) { - self.snapshot = snapshot; - } - - fn delta_window(&self) -> Option<(u32, u32)> { - let min = self.deltas.front().map(|d| d.serial()); - let max = self.deltas.back().map(|d| d.serial()); - match (min, max) { - (Some(min), Some(max)) => Some((min, max)), - _ => None, - } - } - - fn store_sync( - &mut self, - store: &RtrStore, - snapshot: Snapshot, - serial: u32, - session_id: u16, - delta: Arc, - ) { - let window = self.delta_window(); - let store = store.clone(); - - tokio::spawn(async move { - if let Err(e) = store.save_delta(&delta) { - tracing::error!("persist delta failed: {:?}", e); - } - if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) { - tracing::error!("persist snapshot/meta failed: {:?}", e); - } - if let Some((min_serial, max_serial)) = window { - if let Err(e) = store.set_delta_window(min_serial, max_serial) { - tracing::error!("persist delta window failed: {:?}", e); - } - } - }); - } - - 
// Update cache. - pub fn update( - &mut self, - new_payloads: Vec, - store: &RtrStore, - ) -> anyhow::Result<()> { - self.last_update_begin = DualTime::now(); - - let new_snapshot = Snapshot::from_payloads(new_payloads); - - if self.snapshot.same_content(&new_snapshot) { - self.last_update_end = DualTime::now(); - return Ok(()); - } - - let (announced, withdrawn) = self.snapshot.diff(&new_snapshot); - - if announced.is_empty() && withdrawn.is_empty() { - self.last_update_end = DualTime::now(); - return Ok(()); - } - - let new_serial = self.next_serial(); - let delta = Arc::new(Delta::new(new_serial, announced, withdrawn)); - - self.push_delta(delta.clone()); - self.replace_snapshot(new_snapshot.clone()); - self.last_update_end = DualTime::now(); - - self.store_sync(store, new_snapshot, new_serial, self.session_id, delta); - - Ok(()) - } - - pub fn session_id(&self) -> u16 { - self.session_id - } - - pub fn snapshot(&self) -> Snapshot { - self.snapshot.clone() - } - - pub fn serial(&self) -> u32 { - self.serial - } - - pub fn timing(&self) -> Timing { - self.timing - } - - pub fn retry_interval(&self) -> Duration { - DEFAULT_RETRY_INTERVAL - } - - pub fn expire_interval(&self) -> Duration { - DEFAULT_EXPIRE_INTERVAL - } - - pub fn current_snapshot(&self) -> (&Snapshot, u32, u16) { - (&self.snapshot, self.serial, self.session_id) - } - - pub fn last_update_begin(&self) -> DualTime { - self.last_update_begin.clone() - } - - pub fn last_update_end(&self) -> DualTime { - self.last_update_end.clone() - } - - pub fn created_at(&self) -> DualTime { - self.created_at.clone() - } -} - -impl RtrCache { - pub fn get_deltas_since( - &self, - client_session: u16, - client_serial: u32, - ) -> SerialResult { - if client_session != self.session_id { - return SerialResult::ResetRequired; - } - - if client_serial == self.serial { - return SerialResult::UpToDate; - } - - if self.deltas.is_empty() { - return SerialResult::ResetRequired; - } - - let oldest_serial = 
self.deltas.front().unwrap().serial; - let newest_serial = self.deltas.back().unwrap().serial; - - let min_supported = oldest_serial.wrapping_sub(1); - if client_serial < min_supported { - return SerialResult::ResetRequired; - } - - if client_serial > self.serial { - return SerialResult::ResetRequired; - } - - let mut result = Vec::new(); - for delta in &self.deltas { - if delta.serial > client_serial { - result.push(delta.clone()); - } - } - - if let Some(first) = result.first() { - if first.serial != client_serial.wrapping_add(1) { - return SerialResult::ResetRequired; - } - } else { - return SerialResult::UpToDate; - } - - let _ = newest_serial; - SerialResult::Deltas(result) - } -} - -pub enum SerialResult { - /// Client is up to date. - UpToDate, - /// Return applicable deltas. - Deltas(Vec>), - /// Delta window cannot cover; reset required. - ResetRequired, -} - -//------------ RTR ordering ------------------------------------------------- - -#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] -enum PayloadPduType { - Ipv4Prefix = 4, - Ipv6Prefix = 6, - RouterKey = 9, - Aspa = 11, -} - -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -enum RouteOriginKey { - V4 { - addr: u32, - plen: u8, - mlen: u8, - asn: u32, - }, - V6 { - addr: u128, - plen: u8, - mlen: u8, - asn: u32, - }, -} - -fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) { - payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce)); -} - -fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering { - let type_a = payload_pdu_type(a); - let type_b = payload_pdu_type(b); - - match type_a.cmp(&type_b) { - Ordering::Equal => {} - other => return other, - } - - match (a, b) { - (Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => { - compare_route_origin_for_rtr(a, b, announce) - } - (Payload::RouterKey(a), Payload::RouterKey(b)) => { - compare_router_key_for_rtr(a, b) - } - (Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b), - _ => 
Ordering::Equal, - } -} - -fn payload_pdu_type(payload: &Payload) -> PayloadPduType { - match payload { - Payload::RouteOrigin(ro) => { - if route_origin_is_ipv4(ro) { - PayloadPduType::Ipv4Prefix - } else { - PayloadPduType::Ipv6Prefix - } - } - Payload::RouterKey(_) => PayloadPduType::RouterKey, - Payload::Aspa(_) => PayloadPduType::Aspa, - } -} - -fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool { - ro.prefix().address.is_ipv4() -} - -fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey { - let prefix = ro.prefix(); - let plen = prefix.prefix_length; - let mlen = ro.max_length(); - let asn = ro.asn().into_u32(); - - match prefix.address { - IPAddress::V4(addr) => { - RouteOriginKey::V4 { - addr: u32::from(addr), - plen, - mlen, - asn, - } - } - IPAddress::V6(addr) => { - RouteOriginKey::V6 { - addr: u128::from(addr), - plen, - mlen, - asn, - } - } - } -} - -fn compare_route_origin_for_rtr( - a: &RouteOrigin, - b: &RouteOrigin, - announce: bool, -) -> Ordering { - match (route_origin_key(a), route_origin_key(b)) { - ( - RouteOriginKey::V4 { - addr: addr_a, - plen: plen_a, - mlen: mlen_a, - asn: asn_a, - }, - RouteOriginKey::V4 { - addr: addr_b, - plen: plen_b, - mlen: mlen_b, - asn: asn_b, - }, - ) => { - if announce { - addr_b.cmp(&addr_a) - .then_with(|| mlen_b.cmp(&mlen_a)) - .then_with(|| plen_b.cmp(&plen_a)) - .then_with(|| asn_b.cmp(&asn_a)) - } else { - addr_a.cmp(&addr_b) - .then_with(|| mlen_a.cmp(&mlen_b)) - .then_with(|| plen_a.cmp(&plen_b)) - .then_with(|| asn_a.cmp(&asn_b)) - } - } - - ( - RouteOriginKey::V6 { - addr: addr_a, - plen: plen_a, - mlen: mlen_a, - asn: asn_a, - }, - RouteOriginKey::V6 { - addr: addr_b, - plen: plen_b, - mlen: mlen_b, - asn: asn_b, - }, - ) => { - if announce { - addr_b.cmp(&addr_a) - .then_with(|| mlen_b.cmp(&mlen_a)) - .then_with(|| plen_b.cmp(&plen_a)) - .then_with(|| asn_b.cmp(&asn_a)) - } else { - addr_a.cmp(&addr_b) - .then_with(|| mlen_a.cmp(&mlen_b)) - .then_with(|| plen_a.cmp(&plen_b)) - .then_with(|| 
asn_a.cmp(&asn_b)) - } - } - - _ => Ordering::Equal, - } -} - -fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> Ordering { - a.ski() - .cmp(&b.ski()) - .then_with(|| a.spki().len().cmp(&b.spki().len())) - .then_with(|| a.spki().cmp(b.spki())) - .then_with(|| a.asn().into_u32().cmp(&b.asn().into_u32())) -} - -fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering { - a.customer_asn() - .into_u32() - .cmp(&b.customer_asn().into_u32()) -} \ No newline at end of file diff --git a/src/rtr/cache/core.rs b/src/rtr/cache/core.rs new file mode 100644 index 0000000..b056ab0 --- /dev/null +++ b/src/rtr/cache/core.rs @@ -0,0 +1,632 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::cmp::Ordering; +use std::sync::Arc; +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use tracing::{debug, info, warn}; + +use crate::rtr::payload::{Payload, Timing}; + +use super::model::{Delta, DualTime, Snapshot}; +use super::ordering::{change_key, ChangeKey}; + +const SERIAL_HALF_RANGE: u32 = 1 << 31; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq)] +pub enum CacheAvailability { + Ready, + NoDataAvailable, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct SessionIds { + ids: [u16; 3], +} + +impl SessionIds { + pub fn from_array(ids: [u16; 3]) -> Self { + Self { ids } + } + + pub fn random_distinct() -> Self { + let mut ids = [0u16; 3]; + for idx in 0..ids.len() { + loop { + let candidate: u16 = rand::random(); + if ids[..idx].iter().all(|existing| *existing != candidate) { + ids[idx] = candidate; + break; + } + } + } + Self { ids } + } + + pub fn get(&self, version: u8) -> u16 { + self.ids[version_index(version)] + } +} + +#[derive(Debug)] +pub struct RtrCache { + availability: CacheAvailability, + session_ids: SessionIds, + serial: u32, + snapshot: Snapshot, + deltas: VecDeque>, + max_delta: u8, + timing: Timing, + last_update_begin: DualTime, + last_update_end: DualTime, + created_at: DualTime, +} + +impl 
Default for RtrCache { + fn default() -> Self { + let now = DualTime::now(); + Self { + availability: CacheAvailability::Ready, + session_ids: SessionIds::random_distinct(), + serial: 0, + snapshot: Snapshot::empty(), + deltas: VecDeque::with_capacity(100), + max_delta: 100, + timing: Timing::default(), + last_update_begin: now.clone(), + last_update_end: now.clone(), + created_at: now, + } + } +} + +pub struct RtrCacheBuilder { + availability: Option, + session_ids: Option, + max_delta: Option, + timing: Option, + serial: Option, + snapshot: Option, + deltas: Option>>, + created_at: Option, +} + +impl RtrCacheBuilder { + pub fn new() -> Self { + Self { + availability: None, + session_ids: None, + max_delta: None, + timing: None, + serial: None, + snapshot: None, + deltas: None, + created_at: None, + } + } + + pub fn session_ids(mut self, v: SessionIds) -> Self { + self.session_ids = Some(v); + self + } + + pub fn availability(mut self, v: CacheAvailability) -> Self { + self.availability = Some(v); + self + } + + pub fn max_delta(mut self, v: u8) -> Self { + self.max_delta = Some(v); + self + } + + pub fn timing(mut self, v: Timing) -> Self { + self.timing = Some(v); + self + } + + pub fn serial(mut self, v: u32) -> Self { + self.serial = Some(v); + self + } + + pub fn snapshot(mut self, v: Snapshot) -> Self { + self.snapshot = Some(v); + self + } + + pub fn deltas(mut self, v: VecDeque>) -> Self { + self.deltas = Some(v); + self + } + + pub fn created_at(mut self, v: DualTime) -> Self { + self.created_at = Some(v); + self + } + + pub fn build(self) -> RtrCache { + let now = DualTime::now(); + let max_delta = self.max_delta.unwrap_or(100); + let timing = self.timing.unwrap_or_default(); + let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty); + let deltas = self + .deltas + .unwrap_or_else(|| VecDeque::with_capacity(max_delta.into())); + + let serial = self.serial.unwrap_or(0); + let created_at = self.created_at.unwrap_or_else(|| now.clone()); + let 
availability = self.availability.unwrap_or(CacheAvailability::Ready); + let session_ids = self + .session_ids + .unwrap_or_else(SessionIds::random_distinct); + + RtrCache { + availability, + session_ids, + serial, + snapshot, + deltas, + max_delta, + timing, + last_update_begin: now.clone(), + last_update_end: now, + created_at, + } + } +} + +impl RtrCache { + fn set_unavailable(&mut self) { + warn!( + "RTR cache entering NoDataAvailable: old_serial={}, snapshot_empty={}, delta_count={}", + self.serial, + self.snapshot.is_empty(), + self.deltas.len() + ); + self.availability = CacheAvailability::NoDataAvailable; + self.snapshot = Snapshot::empty(); + self.deltas.clear(); + } + + fn reinitialize_from_snapshot(&mut self, snapshot: Snapshot) -> AppliedUpdate { + let old_serial = self.serial; + let old_session_ids = self.session_ids.clone(); + self.availability = CacheAvailability::Ready; + self.session_ids = SessionIds::random_distinct(); + self.serial = 1; + self.snapshot = snapshot.clone(); + self.deltas.clear(); + self.last_update_end = DualTime::now(); + info!( + "RTR cache reinitialized from usable snapshot: old_serial={}, new_serial={}, old_session_ids={:?}, new_session_ids={:?}, payloads(route_origins={}, router_keys={}, aspas={})", + old_serial, + self.serial, + old_session_ids, + self.session_ids, + snapshot.origins().len(), + snapshot.router_keys().len(), + snapshot.aspas().len() + ); + + AppliedUpdate { + availability: self.availability, + snapshot, + serial: self.serial, + session_ids: self.session_ids.clone(), + delta: None, + delta_window: None, + clear_delta_window: true, + } + } + + fn next_serial(&mut self) -> u32 { + let old = self.serial; + self.serial = self.serial.wrapping_add(1); + debug!( + "RTR cache advanced serial: old_serial={}, new_serial={}", + old, + self.serial + ); + self.serial + } + + fn push_delta(&mut self, delta: Arc) { + let dropped_serial = if self.deltas.len() >= self.max_delta as usize { + self.deltas.front().map(|oldest| 
oldest.serial()) + } else { + None + }; + if self.deltas.len() >= self.max_delta as usize { + self.deltas.pop_front(); + } + debug!( + "RTR cache pushing delta into window: delta_serial={}, announced={}, withdrawn={}, dropped_oldest_serial={:?}, window_size_before={}, max_delta={}", + delta.serial(), + delta.announced().len(), + delta.withdrawn().len(), + dropped_serial, + self.deltas.len(), + self.max_delta + ); + self.deltas.push_back(delta); + } + + fn replace_snapshot(&mut self, snapshot: Snapshot) { + self.snapshot = snapshot; + } + + fn delta_window(&self) -> Option<(u32, u32)> { + let min = self.deltas.front().map(|d| d.serial()); + let max = self.deltas.back().map(|d| d.serial()); + match (min, max) { + (Some(min), Some(max)) => Some((min, max)), + _ => None, + } + } + + pub(super) fn apply_update(&mut self, new_payloads: Vec) -> Result> { + self.last_update_begin = DualTime::now(); + info!( + "RTR cache applying update: availability={:?}, current_serial={}, incoming_payloads={}", + self.availability, + self.serial, + new_payloads.len() + ); + + let new_snapshot = Snapshot::from_payloads(new_payloads); + debug!( + "RTR cache built new snapshot from update: route_origins={}, router_keys={}, aspas={}, snapshot_empty={}", + new_snapshot.origins().len(), + new_snapshot.router_keys().len(), + new_snapshot.aspas().len(), + new_snapshot.is_empty() + ); + + if new_snapshot.is_empty() { + let changed = self.availability != CacheAvailability::NoDataAvailable + || !self.snapshot.is_empty() + || !self.deltas.is_empty(); + + self.set_unavailable(); + self.last_update_end = DualTime::now(); + + if !changed { + debug!("RTR cache update produced empty snapshot but cache was already unavailable; no state change"); + return Ok(None); + } + + info!( + "RTR cache update cleared usable data and marked cache unavailable: serial={}, session_ids={:?}", + self.serial, + self.session_ids + ); + + return Ok(Some(AppliedUpdate { + availability: self.availability, + snapshot: 
Snapshot::empty(), + serial: self.serial, + session_ids: self.session_ids.clone(), + delta: None, + delta_window: None, + clear_delta_window: true, + })); + } + + if self.availability == CacheAvailability::NoDataAvailable { + info!("RTR cache recovered from NoDataAvailable with non-empty snapshot"); + return Ok(Some(self.reinitialize_from_snapshot(new_snapshot))); + } + + if self.snapshot.same_content(&new_snapshot) { + self.last_update_end = DualTime::now(); + debug!( + "RTR cache update detected identical snapshot content: serial={}, session_ids={:?}", + self.serial, + self.session_ids + ); + return Ok(None); + } + + let (announced, withdrawn) = self.snapshot.diff(&new_snapshot); + debug!( + "RTR cache diff computed: announced={}, withdrawn={}, current_serial={}", + announced.len(), + withdrawn.len(), + self.serial + ); + + if announced.is_empty() && withdrawn.is_empty() { + self.last_update_end = DualTime::now(); + debug!("RTR cache diff was empty after normalization; no update applied"); + return Ok(None); + } + + let new_serial = self.next_serial(); + let delta = Arc::new(Delta::new(new_serial, announced, withdrawn)); + + if delta.is_empty() { + self.last_update_end = DualTime::now(); + debug!( + "RTR cache delta collapsed to empty after dedup/order normalization: serial={}", + new_serial + ); + return Ok(None); + } + + self.push_delta(delta.clone()); + self.replace_snapshot(new_snapshot.clone()); + self.last_update_end = DualTime::now(); + let delta_window = self.delta_window(); + info!( + "RTR cache applied update: serial={}, announced={}, withdrawn={}, delta_window={:?}, snapshot(route_origins={}, router_keys={}, aspas={})", + new_serial, + delta.announced().len(), + delta.withdrawn().len(), + delta_window, + new_snapshot.origins().len(), + new_snapshot.router_keys().len(), + new_snapshot.aspas().len() + ); + + Ok(Some(AppliedUpdate { + availability: self.availability, + snapshot: new_snapshot, + serial: new_serial, + session_ids: self.session_ids.clone(), 
+ delta: Some(delta), + delta_window, + clear_delta_window: false, + })) + } + + pub fn is_data_available(&self) -> bool { + self.availability == CacheAvailability::Ready + } + + pub fn availability(&self) -> CacheAvailability { + self.availability + } + + pub fn session_id_for_version(&self, version: u8) -> u16 { + self.session_ids.get(version) + } + + pub fn session_ids(&self) -> SessionIds { + self.session_ids.clone() + } + + pub fn snapshot(&self) -> Snapshot { + self.snapshot.clone() + } + + pub fn serial(&self) -> u32 { + self.serial + } + + pub fn timing(&self) -> Timing { + self.timing + } + + pub fn current_snapshot_with_session_ids(&self) -> (&Snapshot, u32, SessionIds) { + (&self.snapshot, self.serial, self.session_ids.clone()) + } + + pub fn last_update_begin(&self) -> DualTime { + self.last_update_begin.clone() + } + + pub fn last_update_end(&self) -> DualTime { + self.last_update_end.clone() + } + + pub fn created_at(&self) -> DualTime { + self.created_at.clone() + } + + pub fn get_deltas_since(&self, client_serial: u32) -> SerialResult { + if client_serial == self.serial { + debug!( + "RTR cache delta query is already up to date: client_serial={}, cache_serial={}", + client_serial, + self.serial + ); + return SerialResult::UpToDate; + } + + if matches!( + serial_cmp(client_serial, self.serial), + Some(Ordering::Greater) | None + ) { + warn!( + "RTR cache delta query requires reset due to invalid/newer client serial: client_serial={}, cache_serial={}", + client_serial, + self.serial + ); + return SerialResult::ResetRequired; + } + + let deltas = match self.collect_deltas_since(client_serial) { + Some(deltas) => deltas, + None => { + warn!( + "RTR cache delta query requires reset because requested serial is outside delta window: client_serial={}, cache_serial={}, delta_window={:?}", + client_serial, + self.serial, + self.delta_window() + ); + return SerialResult::ResetRequired; + } + }; + + if deltas.is_empty() { + debug!( + "RTR cache delta query 
resolved to no deltas: client_serial={}, cache_serial={}", + client_serial, + self.serial + ); + return SerialResult::UpToDate; + } + + let merged = self.merge_deltas_minimally(&deltas); + + if merged.is_empty() { + debug!( + "RTR cache merged delta query to empty result: client_serial={}, cache_serial={}, source_deltas={}", + client_serial, + self.serial, + deltas.len() + ); + SerialResult::UpToDate + } else { + info!( + "RTR cache serving delta query: client_serial={}, cache_serial={}, source_deltas={}, merged_announced={}, merged_withdrawn={}", + client_serial, + self.serial, + deltas.len(), + merged.announced().len(), + merged.withdrawn().len() + ); + SerialResult::Delta(merged) + } + } + + fn collect_deltas_since(&self, client_serial: u32) -> Option>> { + if self.deltas.is_empty() { + return None; + } + + let oldest_serial = self.deltas.front().unwrap().serial(); + let min_supported = oldest_serial.wrapping_sub(1); + + if matches!( + serial_cmp(client_serial, min_supported), + Some(Ordering::Less) | None + ) { + return None; + } + + let mut result = Vec::new(); + for delta in &self.deltas { + if serial_gt(delta.serial(), client_serial) { + result.push(delta.clone()); + } + } + + if let Some(first) = result.first() { + if first.serial() != client_serial.wrapping_add(1) { + return None; + } + } + + Some(result) + } + + fn merge_deltas_minimally(&self, deltas: &[Arc]) -> Delta { + let mut states = BTreeMap::::new(); + + for delta in deltas { + for payload in delta.withdrawn() { + let key = change_key(payload); + let state = states.entry(key).or_insert_with(LogicalState::new); + + if state.before.is_none() && state.after.is_none() { + state.before = Some(payload.clone()); + } + state.after = None; + } + + for payload in delta.announced() { + let key = change_key(payload); + let state = states.entry(key).or_insert_with(LogicalState::new); + + state.after = Some(payload.clone()); + } + } + + let mut announced = Vec::new(); + let mut withdrawn = Vec::new(); + + for 
(_key, state) in states { + match (state.before, state.after) { + (None, None) => {} + (None, Some(new_payload)) => { + announced.push(new_payload); + } + (Some(old_payload), None) => { + withdrawn.push(old_payload); + } + (Some(old_payload), Some(new_payload)) => { + if old_payload != new_payload { + if matches!(old_payload, Payload::Aspa(_)) + && matches!(new_payload, Payload::Aspa(_)) + { + announced.push(new_payload); + } else { + withdrawn.push(old_payload); + announced.push(new_payload); + } + } + } + } + } + + Delta::new(self.serial, announced, withdrawn) + } +} + +#[derive(Debug, Clone, Default)] +struct LogicalState { + before: Option<Payload>, + after: Option<Payload>, +} + +impl LogicalState { + fn new() -> Self { + Self { + before: None, + after: None, + } + } +} + +pub enum SerialResult { + UpToDate, + Delta(Delta), + ResetRequired, +} + +pub(super) struct AppliedUpdate { + pub(super) availability: CacheAvailability, + pub(super) snapshot: Snapshot, + pub(super) serial: u32, + pub(super) session_ids: SessionIds, + pub(super) delta: Option<Arc<Delta>>, + pub(super) delta_window: Option<(u32, u32)>, + pub(super) clear_delta_window: bool, +} + +fn serial_cmp(a: u32, b: u32) -> Option<Ordering> { + if a == b { + return Some(Ordering::Equal); + } + + let diff = a.wrapping_sub(b); + if diff == SERIAL_HALF_RANGE { + None + } else if diff < SERIAL_HALF_RANGE { + Some(Ordering::Greater) + } else { + Some(Ordering::Less) + } +} + +fn serial_gt(a: u32, b: u32) -> bool { + matches!(serial_cmp(a, b), Some(Ordering::Greater)) +} + +fn version_index(version: u8) -> usize { + match version { + 0..=2 => version as usize, + _ => panic!("unsupported RTR protocol version: {}", version), + } +} diff --git a/src/rtr/cache/mod.rs b/src/rtr/cache/mod.rs new file mode 100644 index 0000000..1a38470 --- /dev/null +++ b/src/rtr/cache/mod.rs @@ -0,0 +1,14 @@ +mod core; +mod model; +mod ordering; +mod store; + +pub use core::{CacheAvailability, RtrCache, RtrCacheBuilder, SerialResult, SessionIds}; +pub use 
model::{Delta, DualTime, Snapshot}; +pub use ordering::{ + OrderingViolation, validate_payload_updates_for_rtr, validate_payloads_for_rtr, +}; + +use std::sync::{Arc, RwLock}; + +pub type SharedRtrCache = Arc<RwLock<RtrCache>>; diff --git a/src/rtr/cache/model.rs b/src/rtr/cache/model.rs new file mode 100644 index 0000000..6f9c41b --- /dev/null +++ b/src/rtr/cache/model.rs @@ -0,0 +1,392 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::time::{Duration, Instant}; + +use chrono::{DateTime, NaiveDateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; + +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; + +use super::ordering::{compare_payload_update_for_rtr, sort_payloads_for_rtr}; + +#[derive(Debug, Clone)] +pub struct DualTime { + instant: Instant, + utc: DateTime<Utc>, +} + +impl DualTime { + pub fn now() -> Self { + Self { + instant: Instant::now(), + utc: Utc::now(), + } + } + + pub fn utc(&self) -> DateTime<Utc> { + self.utc + } + + pub fn elapsed(&self) -> Duration { + self.instant.elapsed() + } + + pub fn is_expired(&self, duration: Duration) -> bool { + self.elapsed() >= duration + } + + pub fn reset(&mut self) { + self.instant = Instant::now(); + self.utc = Utc::now(); + } +} + +impl Serialize for DualTime { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + self.utc.timestamp_millis().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for DualTime { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let millis = i64::deserialize(deserializer)?; + let naive = NaiveDateTime::from_timestamp_millis(millis) + .ok_or_else(|| serde::de::Error::custom("invalid timestamp"))?; + let utc = DateTime::<Utc>::from_utc(naive, Utc); + + Ok(Self { + instant: Instant::now(), + utc, + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Snapshot { + origins: BTreeSet<RouteOrigin>, + router_keys: BTreeSet<RouterKey>, + aspas: BTreeSet<Aspa>, + created_at: DualTime, + origins_hash: [u8; 
32], + router_keys_hash: [u8; 32], + aspas_hash: [u8; 32], + snapshot_hash: [u8; 32], +} + +impl Snapshot { + pub fn new( + origins: BTreeSet, + router_keys: BTreeSet, + aspas: BTreeSet, + ) -> Self { + let mut snapshot = Snapshot { + origins, + router_keys, + aspas: normalize_aspas(aspas), + created_at: DualTime::now(), + origins_hash: [0u8; 32], + router_keys_hash: [0u8; 32], + aspas_hash: [0u8; 32], + snapshot_hash: [0u8; 32], + }; + snapshot.recompute_hashes(); + snapshot + } + + pub fn empty() -> Self { + Self::new(BTreeSet::new(), BTreeSet::new(), BTreeSet::new()) + } + + pub fn from_payloads(payloads: Vec) -> Self { + let mut origins = BTreeSet::new(); + let mut router_keys = BTreeSet::new(); + let mut aspas = Vec::new(); + + for p in payloads { + match p { + Payload::RouteOrigin(o) => { + origins.insert(o); + } + Payload::RouterKey(k) => { + router_keys.insert(k); + } + Payload::Aspa(a) => { + aspas.push(a); + } + } + } + + Snapshot::new(origins, router_keys, normalize_aspas(aspas)) + } + + pub fn recompute_hashes(&mut self) { + self.origins_hash = self.compute_origins_hash(); + self.router_keys_hash = self.compute_router_keys_hash(); + self.aspas_hash = self.compute_aspas_hash(); + self.snapshot_hash = self.compute_snapshot_hash(); + } + + fn compute_origins_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.origins.iter()) + } + + fn compute_router_keys_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.router_keys.iter()) + } + + fn compute_aspas_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.aspas.iter()) + } + + fn compute_snapshot_hash(&self) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(b"snapshot:v1"); + hasher.update(self.origins_hash); + hasher.update(self.router_keys_hash); + hasher.update(self.aspas_hash); + hasher.finalize().into() + } + + fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32] + where + T: Serialize + 'a, + I: IntoIterator, + { + let mut hasher = Sha256::new(); + hasher.update(b"set:v1"); + 
+ for item in iter { + let encoded = + serde_json::to_vec(item).expect("serialize snapshot item for hashing failed"); + let len = (encoded.len() as u32).to_be_bytes(); + hasher.update(len); + hasher.update(encoded); + } + + hasher.finalize().into() + } + + pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec, Vec) { + let mut announced = Vec::new(); + let mut withdrawn = Vec::new(); + + if !self.same_origins(new_snapshot) { + for origin in new_snapshot.origins.difference(&self.origins) { + announced.push(Payload::RouteOrigin(origin.clone())); + } + for origin in self.origins.difference(&new_snapshot.origins) { + withdrawn.push(Payload::RouteOrigin(origin.clone())); + } + } + + if !self.same_router_keys(new_snapshot) { + for key in new_snapshot.router_keys.difference(&self.router_keys) { + announced.push(Payload::RouterKey(key.clone())); + } + for key in self.router_keys.difference(&new_snapshot.router_keys) { + withdrawn.push(Payload::RouterKey(key.clone())); + } + } + + if !self.same_aspas(new_snapshot) { + diff_aspas(&self.aspas, &new_snapshot.aspas, &mut announced, &mut withdrawn); + } + + (announced, withdrawn) + } + + pub fn created_at(&self) -> DualTime { + self.created_at.clone() + } + + pub fn payloads(&self) -> Vec { + let mut v = Vec::with_capacity( + self.origins.len() + self.router_keys.len() + self.aspas.len(), + ); + + v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin)); + v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey)); + v.extend(self.aspas.iter().cloned().map(Payload::Aspa)); + + v + } + + pub fn payloads_for_rtr(&self) -> Vec { + let mut payloads = self.payloads(); + sort_payloads_for_rtr(&mut payloads, true); + payloads + } + + pub fn origins_hash(&self) -> [u8; 32] { + self.origins_hash + } + + pub fn router_keys_hash(&self) -> [u8; 32] { + self.router_keys_hash + } + + pub fn aspas_hash(&self) -> [u8; 32] { + self.aspas_hash + } + + pub fn snapshot_hash(&self) -> [u8; 32] { + self.snapshot_hash + } + + pub fn 
same_origins(&self, other: &Self) -> bool { + self.origins_hash == other.origins_hash + } + + pub fn same_router_keys(&self, other: &Self) -> bool { + self.router_keys_hash == other.router_keys_hash + } + + pub fn same_aspas(&self, other: &Self) -> bool { + self.aspas_hash == other.aspas_hash + } + + pub fn same_content(&self, other: &Self) -> bool { + self.snapshot_hash == other.snapshot_hash + } + + pub fn origins(&self) -> &BTreeSet { + &self.origins + } + + pub fn router_keys(&self) -> &BTreeSet { + &self.router_keys + } + + pub fn aspas(&self) -> &BTreeSet { + &self.aspas + } + + pub fn is_empty(&self) -> bool { + self.origins.is_empty() + && self.router_keys.is_empty() + && self.aspas.is_empty() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Delta { + serial: u32, + announced: Vec, + withdrawn: Vec, + created_at: DualTime, +} + +impl Delta { + pub fn new(serial: u32, mut announced: Vec, mut withdrawn: Vec) -> Self { + dedup_payloads(&mut announced); + dedup_payloads(&mut withdrawn); + + sort_payloads_for_rtr(&mut announced, true); + sort_payloads_for_rtr(&mut withdrawn, false); + + Delta { + serial, + announced, + withdrawn, + created_at: DualTime::now(), + } + } + + pub fn serial(&self) -> u32 { + self.serial + } + + pub fn announced(&self) -> &[Payload] { + &self.announced + } + + pub fn withdrawn(&self) -> &[Payload] { + &self.withdrawn + } + + pub fn created_at(&self) -> DualTime { + self.created_at.clone() + } + + pub fn is_empty(&self) -> bool { + self.announced.is_empty() && self.withdrawn.is_empty() + } + + pub fn payloads_for_rtr(&self) -> Vec<(bool, Payload)> { + let mut updates = Vec::with_capacity(self.announced.len() + self.withdrawn.len()); + + updates.extend(self.announced.iter().cloned().map(|p| (true, p))); + updates.extend(self.withdrawn.iter().cloned().map(|p| (false, p))); + + updates.sort_by(|(a_upd, a_payload), (b_upd, b_payload)| { + compare_payload_update_for_rtr(a_payload, *a_upd, b_payload, *b_upd) + }); + + updates + 
} +} + +fn dedup_payloads(payloads: &mut Vec) { + let mut seen = BTreeSet::new(); + payloads.retain(|p| seen.insert(p.clone())); +} + +fn normalize_aspas(aspas: I) -> BTreeSet +where + I: IntoIterator, +{ + let mut by_customer = BTreeMap::>::new(); + + for aspa in aspas { + let providers = by_customer + .entry(aspa.customer_asn().into_u32()) + .or_default(); + providers.extend(aspa.provider_asns().iter().copied()); + } + + by_customer + .into_iter() + .map(|(customer_asn, providers)| { + Aspa::new(customer_asn.into(), providers.into_iter().collect()) + }) + .collect() +} + +fn diff_aspas( + current: &BTreeSet, + next: &BTreeSet, + announced: &mut Vec, + withdrawn: &mut Vec, +) { + let current = current + .iter() + .map(|aspa| (aspa.customer_asn().into_u32(), aspa)) + .collect::>(); + let next = next + .iter() + .map(|aspa| (aspa.customer_asn().into_u32(), aspa)) + .collect::>(); + + let customers = current + .keys() + .chain(next.keys()) + .copied() + .collect::>(); + + for customer in customers { + match (current.get(&customer), next.get(&customer)) { + (None, Some(new_aspa)) => announced.push(Payload::Aspa((*new_aspa).clone())), + (Some(old_aspa), None) => withdrawn.push(Payload::Aspa((*old_aspa).clone())), + (Some(old_aspa), Some(new_aspa)) if old_aspa != new_aspa => { + announced.push(Payload::Aspa((*new_aspa).clone())); + } + _ => {} + } + } +} diff --git a/src/rtr/cache/ordering.rs b/src/rtr/cache/ordering.rs new file mode 100644 index 0000000..75b7392 --- /dev/null +++ b/src/rtr/cache/ordering.rs @@ -0,0 +1,311 @@ +use std::cmp::Ordering; +use std::fmt; + +use crate::data_model::resources::ip_resources::IPAddress; +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski}; + +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct OrderingViolation { + index: usize, + left: String, + right: String, +} + +impl OrderingViolation { + fn new(index: usize, left: &Payload, right: &Payload) -> Self { + Self { + index, + left: payload_brief(left), + right: 
payload_brief(right), + } + } + + fn new_update(index: usize, left: (bool, &Payload), right: (bool, &Payload)) -> Self { + Self { + index, + left: payload_update_brief(left.0, left.1), + right: payload_update_brief(right.0, right.1), + } + } +} + +impl fmt::Display for OrderingViolation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "RTR payload ordering violation at positions {} and {}: {} should not appear before {}", + self.index, + self.index + 1, + self.left, + self.right + ) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub(crate) enum RouterKeyKey { + Key { ski: Ski, spki: Vec, asn: u32 }, +} + +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub(crate) enum ChangeKey { + RouteOrigin(RouteOriginKey), + RouterKey(RouterKeyKey), + AspaCustomer(u32), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] +enum PayloadPduType { + Ipv4Prefix = 4, + Ipv6Prefix = 6, + RouterKey = 9, + Aspa = 11, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub(crate) enum RouteOriginKey { + V4 { addr: u32, plen: u8, mlen: u8, asn: u32 }, + V6 { addr: u128, plen: u8, mlen: u8, asn: u32 }, +} + +pub(crate) fn change_key(payload: &Payload) -> ChangeKey { + match payload { + Payload::RouteOrigin(ro) => ChangeKey::RouteOrigin(route_origin_key(ro)), + Payload::RouterKey(rk) => ChangeKey::RouterKey(router_key_key(rk)), + Payload::Aspa(aspa) => ChangeKey::AspaCustomer(aspa.customer_asn().into_u32()), + } +} + +pub(crate) fn compare_payload_update_for_rtr( + a_payload: &Payload, + a_announce: bool, + b_payload: &Payload, + b_announce: bool, +) -> Ordering { + let type_a = payload_pdu_type(a_payload); + let type_b = payload_pdu_type(b_payload); + + match type_a.cmp(&type_b) { + Ordering::Equal => {} + other => return other, + } + + match b_announce.cmp(&a_announce) { + Ordering::Equal => {} + other => return other, + } + + match (a_payload, b_payload) { + (Payload::RouteOrigin(a), 
Payload::RouteOrigin(b)) => { + compare_route_origin_for_rtr(a, b, a_announce) + } + (Payload::RouterKey(a), Payload::RouterKey(b)) => compare_router_key_for_rtr(a, b), + (Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b), + _ => Ordering::Equal, + } +} + +pub(crate) fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) { + payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce)); +} + +pub fn validate_payloads_for_rtr( + payloads: &[Payload], + announce: bool, +) -> Result<(), OrderingViolation> { + for (index, pair) in payloads.windows(2).enumerate() { + if compare_payload_for_rtr(&pair[0], &pair[1], announce) == Ordering::Greater { + return Err(OrderingViolation::new(index, &pair[0], &pair[1])); + } + } + + Ok(()) +} + +pub fn validate_payload_updates_for_rtr( + updates: &[(bool, Payload)], +) -> Result<(), OrderingViolation> { + for (index, pair) in updates.windows(2).enumerate() { + if compare_payload_update_for_rtr(&pair[0].1, pair[0].0, &pair[1].1, pair[1].0) + == Ordering::Greater + { + return Err(OrderingViolation::new_update( + index, + (pair[0].0, &pair[0].1), + (pair[1].0, &pair[1].1), + )); + } + } + + Ok(()) +} + +fn router_key_key(rk: &RouterKey) -> RouterKeyKey { + RouterKeyKey::Key { + ski: rk.ski(), + spki: rk.spki().as_ref().to_vec(), + asn: rk.asn().into_u32(), + } +} + +fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering { + let type_a = payload_pdu_type(a); + let type_b = payload_pdu_type(b); + + match type_a.cmp(&type_b) { + Ordering::Equal => {} + other => return other, + } + + match (a, b) { + (Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => { + compare_route_origin_for_rtr(a, b, announce) + } + (Payload::RouterKey(a), Payload::RouterKey(b)) => compare_router_key_for_rtr(a, b), + (Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b), + _ => Ordering::Equal, + } +} + +fn payload_pdu_type(payload: &Payload) -> PayloadPduType { + match payload { + 
Payload::RouteOrigin(ro) => { + if route_origin_is_ipv4(ro) { + PayloadPduType::Ipv4Prefix + } else { + PayloadPduType::Ipv6Prefix + } + } + Payload::RouterKey(_) => PayloadPduType::RouterKey, + Payload::Aspa(_) => PayloadPduType::Aspa, + } +} + +fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool { + ro.prefix().address.is_ipv4() +} + +fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey { + let prefix = ro.prefix(); + let plen = prefix.prefix_length; + let mlen = ro.max_length(); + let asn = ro.asn().into_u32(); + + match prefix.address { + IPAddress::V4(addr) => RouteOriginKey::V4 { + addr: u32::from(addr), + plen, + mlen, + asn, + }, + IPAddress::V6(addr) => RouteOriginKey::V6 { + addr: u128::from(addr), + plen, + mlen, + asn, + }, + } +} + +fn compare_route_origin_for_rtr(a: &RouteOrigin, b: &RouteOrigin, announce: bool) -> Ordering { + match (route_origin_key(a), route_origin_key(b)) { + ( + RouteOriginKey::V4 { + addr: addr_a, + plen: plen_a, + mlen: mlen_a, + asn: asn_a, + }, + RouteOriginKey::V4 { + addr: addr_b, + plen: plen_b, + mlen: mlen_b, + asn: asn_b, + }, + ) => { + if announce { + addr_b + .cmp(&addr_a) + .then_with(|| mlen_b.cmp(&mlen_a)) + .then_with(|| plen_b.cmp(&plen_a)) + .then_with(|| asn_b.cmp(&asn_a)) + } else { + addr_a + .cmp(&addr_b) + .then_with(|| mlen_a.cmp(&mlen_b)) + .then_with(|| plen_a.cmp(&plen_b)) + .then_with(|| asn_a.cmp(&asn_b)) + } + } + ( + RouteOriginKey::V6 { + addr: addr_a, + plen: plen_a, + mlen: mlen_a, + asn: asn_a, + }, + RouteOriginKey::V6 { + addr: addr_b, + plen: plen_b, + mlen: mlen_b, + asn: asn_b, + }, + ) => { + if announce { + addr_b + .cmp(&addr_a) + .then_with(|| mlen_b.cmp(&mlen_a)) + .then_with(|| plen_b.cmp(&plen_a)) + .then_with(|| asn_b.cmp(&asn_a)) + } else { + addr_a + .cmp(&addr_b) + .then_with(|| mlen_a.cmp(&mlen_b)) + .then_with(|| plen_a.cmp(&plen_b)) + .then_with(|| asn_a.cmp(&asn_b)) + } + } + _ => Ordering::Equal, + } +} + +fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> 
Ordering { + a.ski() + .cmp(&b.ski()) + .then_with(|| a.spki().len().cmp(&b.spki().len())) + .then_with(|| a.spki().cmp(b.spki())) + .then_with(|| a.asn().into_u32().cmp(&b.asn().into_u32())) +} + +fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering { + a.customer_asn() + .into_u32() + .cmp(&b.customer_asn().into_u32()) +} + +fn payload_brief(payload: &Payload) -> String { + match payload { + Payload::RouteOrigin(origin) => format!( + "{} prefix {:?}/{} max={} asn={}", + if route_origin_is_ipv4(origin) { "IPv4" } else { "IPv6" }, + origin.prefix().address, + origin.prefix().prefix_length, + origin.max_length(), + origin.asn().into_u32() + ), + Payload::RouterKey(key) => format!( + "RouterKey ski={:02x?} asn={}", + key.ski().as_ref(), + key.asn().into_u32() + ), + Payload::Aspa(aspa) => format!("ASPA customer_asn={}", aspa.customer_asn().into_u32()), + } +} + +fn payload_update_brief(announce: bool, payload: &Payload) -> String { + format!( + "{} {}", + if announce { "announce" } else { "withdraw" }, + payload_brief(payload) + ) +} diff --git a/src/rtr/cache/store.rs b/src/rtr/cache/store.rs new file mode 100644 index 0000000..4efc655 --- /dev/null +++ b/src/rtr/cache/store.rs @@ -0,0 +1,195 @@ +use std::collections::VecDeque; +use std::sync::Arc; + +use anyhow::Result; + +use crate::rtr::payload::{Payload, Timing}; +use crate::rtr::store::RtrStore; + +use super::core::{AppliedUpdate, CacheAvailability, RtrCache, RtrCacheBuilder, SessionIds}; +use super::model::Snapshot; + +impl RtrCache { + pub fn init( + self, + store: &RtrStore, + max_delta: u8, + timing: Timing, + file_loader: impl Fn() -> Result>, + ) -> Result { + if let Some(cache) = try_restore_from_store(store, max_delta, timing)? 
{ + tracing::info!( + "RTR cache restored from store: availability={:?}, session_ids={:?}, serial={}, snapshot(route_origins={}, router_keys={}, aspas={})", + cache.availability(), + cache.session_ids(), + cache.serial(), + cache.snapshot().origins().len(), + cache.snapshot().router_keys().len(), + cache.snapshot().aspas().len() + ); + return Ok(cache); + } + + tracing::warn!("RTR cache store unavailable or invalid, fallback to file loader"); + + let payloads = file_loader()?; + let session_ids = SessionIds::random_distinct(); + let snapshot = Snapshot::from_payloads(payloads); + let availability = if snapshot.is_empty() { + CacheAvailability::NoDataAvailable + } else { + CacheAvailability::Ready + }; + let serial = if snapshot.is_empty() { 0 } else { 1 }; + + if snapshot.is_empty() { + tracing::warn!( + "RTR cache initialized without usable data: session_ids={:?}, serial={}", + session_ids, + serial + ); + } else { + tracing::info!( + "RTR cache initialized from file loader: session_ids={:?}, serial={}", + session_ids, + serial + ); + } + + let snapshot_for_store = snapshot.clone(); + let session_ids_for_store = session_ids.clone(); + + tokio::spawn({ + let store = store.clone(); + async move { + if let Err(e) = store.save_cache_state( + availability, + &snapshot_for_store, + &session_ids_for_store, + serial, + None, + None, + true, + ) { + tracing::error!("persist cache state failed: {:?}", e); + } + } + }); + + Ok(RtrCacheBuilder::new() + .availability(availability) + .session_ids(session_ids) + .max_delta(max_delta) + .timing(timing) + .serial(serial) + .snapshot(snapshot) + .build()) + } + + pub fn update(&mut self, new_payloads: Vec, store: &RtrStore) -> Result<()> { + if let Some(update) = self.apply_update(new_payloads)? 
{ + spawn_store_sync(store, update); + } + + Ok(()) + } +} + +fn try_restore_from_store(store: &RtrStore, max_delta: u8, timing: Timing) -> Result<Option<RtrCache>> { + let snapshot = store.get_snapshot()?; + let session_ids = store.get_session_ids()?; + let serial = store.get_serial()?; + let availability = store.get_availability()?; + + let (snapshot, session_ids, serial) = match (snapshot, session_ids, serial) { + (Some(snapshot), Some(session_ids), Some(serial)) => (snapshot, session_ids, serial), + _ => { + tracing::warn!("RTR cache store incomplete: snapshot/session_ids/serial missing"); + return Ok(None); + } + }; + + let availability = availability.unwrap_or_else(|| { + tracing::warn!("RTR cache store missing availability metadata, defaulting to Ready"); + CacheAvailability::Ready + }); + + let deltas = if availability == CacheAvailability::NoDataAvailable { + tracing::warn!("RTR cache store restored in no-data-available state"); + VecDeque::with_capacity(max_delta.into()) + } else { + match store.get_delta_window()? 
{ + Some((min_serial, max_serial)) => { + match store.load_delta_window(min_serial, max_serial) { + Ok(deltas) => deltas.into_iter().map(Arc::new).collect(), + Err(err) => { + tracing::warn!( + "RTR cache store delta recovery failed, treat store as unusable: {:?}", + err + ); + return Ok(None); + } + } + } + None => { + tracing::info!("RTR cache store has no delta window, restore snapshot only"); + VecDeque::with_capacity(max_delta.into()) + } + } + }; + + Ok(Some( + RtrCacheBuilder::new() + .availability(availability) + .session_ids(session_ids) + .max_delta(max_delta) + .timing(timing) + .serial(serial) + .snapshot(snapshot) + .deltas(deltas) + .build(), + )) +} + +fn spawn_store_sync(store: &RtrStore, update: AppliedUpdate) { + let AppliedUpdate { + availability, + snapshot, + serial, + session_ids, + delta, + delta_window, + clear_delta_window, + } = update; + + tokio::spawn({ + let store = store.clone(); + async move { + tracing::debug!( + "persisting RTR cache state: availability={:?}, serial={}, session_ids={:?}, delta_present={}, delta_window={:?}, clear_delta_window={}, snapshot(route_origins={}, router_keys={}, aspas={})", + availability, + serial, + session_ids, + delta.is_some(), + delta_window, + clear_delta_window, + snapshot.origins().len(), + snapshot.router_keys().len(), + snapshot.aspas().len() + ); + if let Err(e) = store.save_cache_state( + availability, + &snapshot, + &session_ids, + serial, + delta.as_deref(), + delta_window, + clear_delta_window, + ) { + tracing::error!("persist cache state failed: {:?}", e); + } else { + tracing::debug!("persist RTR cache state completed: serial={}", serial); + } + } + }); +} diff --git a/src/rtr/mod.rs b/src/rtr/mod.rs index df72a99..a085a4a 100644 --- a/src/rtr/mod.rs +++ b/src/rtr/mod.rs @@ -1,9 +1,9 @@ pub mod pdu; pub mod cache; pub mod payload; -pub mod store_db; +pub mod store; pub mod session; pub mod error_type; pub mod state; pub mod server; -pub mod loader; \ No newline at end of file +pub mod 
loader; diff --git a/src/rtr/payload.rs b/src/rtr/payload.rs index 42f24cc..c1f393e 100644 --- a/src/rtr/payload.rs +++ b/src/rtr/payload.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::io; use std::time::Duration; use serde::{Deserialize, Serialize}; use crate::data_model::resources::as_resources::Asn; @@ -16,6 +17,11 @@ enum PayloadPduType { #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] pub struct Ski([u8; 20]); +impl AsRef<[u8]> for Ski { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] pub struct RouteOrigin { @@ -100,10 +106,35 @@ impl Aspa { pub fn provider_asns(&self) -> &[Asn] { &self.provider_asns } + + pub fn validate_announcement(&self) -> Result<(), io::Error> { + if self.customer_asn.into_u32() == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ASPA customer ASN must not be AS0", + )); + } + + if self.provider_asns.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ASPA announcement must contain at least one provider ASN", + )); + } + + if self.provider_asns.iter().any(|asn| asn.into_u32() == 0) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ASPA provider list must not contain AS0", + )); + } + + Ok(()) + } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub enum Payload { /// A route origin. 
@@ -131,10 +162,71 @@ pub struct Timing { } impl Timing { + pub const MIN_REFRESH: u32 = 1; + pub const MAX_REFRESH: u32 = 86_400; + pub const MIN_RETRY: u32 = 1; + pub const MAX_RETRY: u32 = 7_200; + pub const MIN_EXPIRE: u32 = 600; + pub const MAX_EXPIRE: u32 = 172_800; + pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self { Self { refresh, retry, expire } } + pub fn validate(self) -> Result<(), io::Error> { + if !(Self::MIN_REFRESH..=Self::MAX_REFRESH).contains(&self.refresh) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "refresh interval {} out of range {}..={}", + self.refresh, Self::MIN_REFRESH, Self::MAX_REFRESH + ), + )); + } + + if !(Self::MIN_RETRY..=Self::MAX_RETRY).contains(&self.retry) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "retry interval {} out of range {}..={}", + self.retry, Self::MIN_RETRY, Self::MAX_RETRY + ), + )); + } + + if !(Self::MIN_EXPIRE..=Self::MAX_EXPIRE).contains(&self.expire) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "expire interval {} out of range {}..={}", + self.expire, Self::MIN_EXPIRE, Self::MAX_EXPIRE + ), + )); + } + + if self.expire <= self.refresh { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "expire interval {} must be greater than refresh interval {}", + self.expire, self.refresh + ), + )); + } + + if self.expire <= self.retry { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "expire interval {} must be greater than retry interval {}", + self.expire, self.retry + ), + )); + } + + Ok(()) + } + pub fn refresh(self) -> Duration { Duration::from_secs(u64::from(self.refresh)) } @@ -157,4 +249,4 @@ impl Default for Timing { expire: 7200, } } -} \ No newline at end of file +} diff --git a/src/rtr/pdu.rs b/src/rtr/pdu.rs index 7edc1e1..c7bdf20 100644 --- a/src/rtr/pdu.rs +++ b/src/rtr/pdu.rs @@ -2,9 +2,9 @@ use std::{cmp, mem}; use std::net::{Ipv4Addr, Ipv6Addr}; use 
std::sync::Arc; use crate::data_model::resources::as_resources::Asn; +use crate::rtr::error_type::ErrorCode; use crate::rtr::payload::{Ski, Timing}; use std::io; -use std::io::Write; use tokio::io::{AsyncWrite}; use anyhow::Result; @@ -12,7 +12,6 @@ use std::slice; use anyhow::bail; use serde::Serialize; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; pub const HEADER_LEN: usize = 8; pub const MAX_PDU_LEN: u32 = 65535; @@ -208,10 +207,15 @@ impl Header { } } - pub async fn read(sock: &mut S) -> Result { + pub async fn read_raw( + sock: &mut S + ) -> Result<[u8; HEADER_LEN], io::Error> { let mut buf = [0u8; HEADER_LEN]; sock.read_exact(&mut buf).await?; + Ok(buf) + } + pub fn from_raw(buf: [u8; HEADER_LEN]) -> Result { let version = buf[0]; let pdu = buf[1]; let session_id = u16::from_be_bytes([buf[2], buf[3]]); @@ -239,6 +243,10 @@ impl Header { }) } + pub async fn read(sock: &mut S) -> Result { + Self::from_raw(Self::read_raw(sock).await?) + } + pub fn version(self) -> u8{self.version} pub fn pdu(self) -> u8{self.pdu} @@ -256,6 +264,10 @@ impl Header { }) } + pub fn error_code(self) -> u16 { + debug_assert_eq!(self.pdu(), ErrorReport::PDU); + self.session_id() + } } @@ -281,7 +293,7 @@ impl HeaderWithFlags { length: length.to_be(), } } - pub async fn read(sock: &mut TcpStream) -> Result { + pub async fn read(sock: &mut S) -> Result { let mut buf = [0u8; HEADER_LEN]; // 1. 
精确读取 8 字节 @@ -311,7 +323,7 @@ impl HeaderWithFlags { pdu, flags, zero, - length, + length: length.to_be(), }) } @@ -344,7 +356,7 @@ impl SerialNotify { } pub fn serial_number(self) -> u32 { - self.serial_number + u32::from_be(self.serial_number) } } @@ -370,7 +382,7 @@ impl SerialQuery { } pub fn serial_number(self) -> u32 { - self.serial_number + u32::from_be(self.serial_number) } } @@ -538,12 +550,17 @@ pub enum EndOfData { } impl EndOfData { - pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self { + pub fn new( + version: u8, + session_id: u16, + serial_number: u32, + timing: Timing, + ) -> Result { if version == 0 { - EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number)) + Ok(EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number))) } else { - EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)) + Ok(EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)?)) } } @@ -588,14 +605,37 @@ pub struct EndOfDataV1 { impl EndOfDataV1 { pub const PDU: u8 = 7; - pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self { - EndOfDataV1 { + pub fn version(&self) -> u8 { + self.header.version() + } + + pub fn session_id(&self) -> u16 { + self.header.session_id() + } + + pub fn pdu(&self) -> u8 { + self.header.pdu() + } + + pub fn size() -> u32 { + mem::size_of::() as u32 + } + + pub fn new( + version: u8, + session_id: u16, + serial_number: u32, + timing: Timing, + ) -> Result { + timing.validate()?; + + Ok(EndOfDataV1 { header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V1_LEN), serial_number: serial_number.to_be(), refresh_interval: timing.refresh.to_be(), retry_interval: timing.retry.to_be(), expire_interval: timing.expire.to_be(), - } + }) } pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} @@ -607,8 +647,50 @@ impl EndOfDataV1 { expire: u32::from_be(self.expire_interval), } } + + fn validate(&self) -> Result<(), 
io::Error> { + self.timing().validate() + } + + pub async fn read( + sock: &mut Sock + ) -> Result { + let mut res = Self::default(); + sock.read_exact(res.header.as_mut()).await?; + if res.header.pdu() != Self::PDU { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU type mismatch when expecting EndOfDataV1", + )) + } + if res.header.length() as usize != mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid length for EndOfDataV1", + )) + } + sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; + res.validate()?; + Ok(res) + } + + pub async fn read_payload( + header: Header, sock: &mut Sock + ) -> Result { + if header.length() as usize != mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid length for EndOfDataV1 PDU", + )) + } + let mut res = Self::default(); + sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; + res.header = header; + res.validate()?; + Ok(res) + } } -concrete!(EndOfDataV1); +common!(EndOfDataV1); // Cache Reset #[repr(C, packed)] @@ -640,6 +722,7 @@ pub struct ErrorReport { impl ErrorReport { /// The PDU type of an error PDU. pub const PDU: u8 = 10; + const FIXED_PART_LEN: usize = HEADER_LEN + 2 * mem::size_of::(); /// Creates a new error PDU from components. pub fn new( @@ -650,12 +733,12 @@ impl ErrorReport { ) -> Self { let pdu = pdu.as_ref(); let text = text.as_ref(); + let max_payload_len = MAX_PDU_LEN as usize - Self::FIXED_PART_LEN; + let pdu_len = cmp::min(pdu.len(), max_payload_len); + let text_room = max_payload_len - pdu_len; + let text_len = cmp::min(text.len(), text_room); - let size = - mem::size_of::
() - + 2 * mem::size_of::() - + pdu.len() + text.len() - ; + let size = Self::FIXED_PART_LEN + pdu_len + text_len; let header = Header::new( version, 10, error_code, u32::try_from(size).unwrap() ); @@ -663,37 +746,92 @@ impl ErrorReport { let mut octets = Vec::with_capacity(size); octets.extend_from_slice(header.as_ref()); octets.extend_from_slice( - u32::try_from(pdu.len()).unwrap().to_be_bytes().as_ref() + u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref() ); - octets.extend_from_slice(pdu); + octets.extend_from_slice(&pdu[..pdu_len]); octets.extend_from_slice( - u32::try_from(text.len()).unwrap().to_be_bytes().as_ref() + u32::try_from(text_len).unwrap().to_be_bytes().as_ref() ); - octets.extend_from_slice(text); + octets.extend_from_slice(&text[..text_len]); ErrorReport { octets } } + pub async fn read( + sock: &mut Sock + ) -> Result { + let header = Header::read(sock).await?; + if header.pdu() != Self::PDU { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU type mismatch when expecting ErrorReport", + )); + } + Self::read_payload(header, sock).await + } + + pub async fn read_payload( + header: Header, + sock: &mut Sock, + ) -> Result { + let total_len = header.pdu_len()?; + let Some(payload_len) = total_len.checked_sub(mem::size_of::
()) else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU size smaller than header size", + )); + }; + + let mut octets = Vec::with_capacity(total_len); + octets.extend_from_slice(header.as_ref()); + octets.resize(total_len, 0); + sock.read_exact(&mut octets[mem::size_of::
()..]).await?; + + let res = ErrorReport { octets }; + res.validate()?; + debug_assert_eq!(payload_len + mem::size_of::
(), res.octets.len()); + Ok(res) + } + + pub fn version(&self) -> u8 { + self.header().version() + } + + pub fn error_code(&self) -> Result { + ErrorCode::try_from(self.header().error_code()).map_err(|_| self.header().error_code()) + } + + pub fn erroneous_pdu(&self) -> &[u8] { + &self.octets[self.erroneous_pdu_range()] + } + + pub fn text(&self) -> &[u8] { + &self.octets[self.text_range()] + } + /// Skips over the payload of the error PDU. pub async fn skip_payload( header: Header, sock: &mut Sock ) -> Result<(), io::Error> { - let Some(mut remaining) = header.pdu_len()?.checked_sub( - mem::size_of::
() - ) else { + let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::
()) else { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU size smaller than header size", - )) + )); }; let mut buf = [0u8; 1024]; while remaining > 0 { let read_len = cmp::min(remaining, mem::size_of_val(&buf)); - let read = sock.read( - // Safety: We limited the length to the buffer size. - unsafe { buf.get_unchecked_mut(..read_len) } - ).await?; + let read = sock.read(&mut buf[..read_len]).await?; + + if read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF while skipping ErrorReport payload", + )); + } + remaining -= read; } Ok(()) @@ -705,11 +843,115 @@ impl ErrorReport { ) -> Result<(), io::Error> { a.write_all(self.as_ref()).await } + + fn header(&self) -> Header { + Header::from_raw(self.header_bytes()).expect("validated ErrorReport header") + } + + fn header_bytes(&self) -> [u8; HEADER_LEN] { + self.octets[..HEADER_LEN] + .try_into() + .expect("ErrorReport shorter than header") + } + + fn erroneous_pdu_len(&self) -> usize { + u32::from_be_bytes( + self.octets[Header::LEN..Header::LEN + 4] + .try_into() + .unwrap() + ) as usize + } + + fn erroneous_pdu_range(&self) -> std::ops::Range { + let start = Header::LEN + 4; + let end = start + self.erroneous_pdu_len(); + start..end + } + + fn text_len_offset(&self) -> usize { + self.erroneous_pdu_range().end + } + + fn text_len(&self) -> usize { + let offset = self.text_len_offset(); + u32::from_be_bytes( + self.octets[offset..offset + 4] + .try_into() + .unwrap() + ) as usize + } + + fn text_range(&self) -> std::ops::Range { + let start = self.text_len_offset() + 4; + let end = start + self.text_len(); + start..end + } + + fn validate(&self) -> Result<(), io::Error> { + let header = Header::from_raw(self.header_bytes())?; + if header.pdu() != Self::PDU { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected PDU type for ErrorReport", + )); + } + + let total_len = header.pdu_len()?; + if total_len != self.octets.len() { + return 
Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport length mismatch", + )); + } + + if self.octets.len() < Header::LEN + 8 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport too short", + )); + } + + let pdu_len = self.erroneous_pdu_len(); + let text_len_offset = Header::LEN + 4 + pdu_len; + let Some(text_len_end) = text_len_offset.checked_add(4) else { + return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport length overflow")); + }; + if text_len_end > self.octets.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport truncated before error text length", + )); + } + + let text_len = u32::from_be_bytes( + self.octets[text_len_offset..text_len_end] + .try_into() + .unwrap() + ) as usize; + let Some(text_end) = text_len_end.checked_add(text_len) else { + return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport text overflow")); + }; + if text_end != self.octets.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport payload length mismatch", + )); + } + + if std::str::from_utf8(self.text()).is_err() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport text is not valid UTF-8", + )); + } + + Ok(()) + } } // TODO: 补全 -// Router Key +/// Router Key #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct RouterKey { header: HeaderWithFlags, @@ -724,17 +966,13 @@ pub struct RouterKey { impl RouterKey { pub const PDU: u8 = 9; + const BASE_LEN: usize = HEADER_LEN + 20 + 4; pub async fn write( &self, w: &mut A, ) -> Result<(), io::Error> { - - let length = HEADER_LEN - + 1 // flags - // + self.ski.as_ref().len() - + 4 // ASN - + self.subject_public_key_info.len(); + let length = Self::BASE_LEN + self.subject_public_key_info.len(); let header = HeaderWithFlags::new( self.header.version(), @@ -750,13 +988,42 @@ impl RouterKey { ZERO_8, ]).await?; - w.write_all(&length.to_be_bytes()).await?; - // 
w.write_all(self.ski.as_ref()).await?; + w.write_all(&(length as u32).to_be_bytes()).await?; + w.write_all(self.ski.as_ref()).await?; w.write_all(&self.asn.into_u32().to_be_bytes()).await?; w.write_all(&self.subject_public_key_info).await?; Ok(()) } + pub fn new( + version: u8, + flags: Flags, + ski: Ski, + asn: Asn, + subject_public_key_info: Arc<[u8]>, + ) -> Self { + let length = Self::BASE_LEN + subject_public_key_info.len(); + + Self { + header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32), + flags, + ski, + asn, + subject_public_key_info, + } + } + + pub fn ski(&self) -> Ski { + self.ski + } + + pub fn asn(&self) -> Asn { + self.asn + } + + pub fn spki(&self) -> &[u8] { + &self.subject_public_key_info + } } @@ -772,21 +1039,19 @@ pub struct Aspa{ impl Aspa { pub const PDU: u8 = 11; + const BASE_LEN: usize = HEADER_LEN + 4; pub async fn write( &self, w: &mut A, ) -> Result<(), io::Error> { - let length = HEADER_LEN - + 1 - + 4 - + (self.provider_asns.len() * 4); + let length = Self::BASE_LEN + (self.provider_asns.len() * 4); let header = HeaderWithFlags::new( self.header.version(), Self::PDU, - Flags::new(self.header.flags), + self.header.flags(), length as u32, ); @@ -797,7 +1062,7 @@ impl Aspa { ZERO_8, ]).await?; - w.write_all(&length.to_be_bytes()).await?; + w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(&self.customer_asn.to_be_bytes()).await?; for asn in &self.provider_asns { @@ -806,6 +1071,20 @@ impl Aspa { Ok(()) } + pub fn new( + version: u8, + flags: Flags, + customer_asn: u32, + provider_asns: Vec, + ) -> Self { + let length = Self::BASE_LEN + (provider_asns.len() * 4); + + Self { + header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32), + customer_asn, + provider_asns, + } + } } @@ -823,54 +1102,3 @@ impl AsMut<[u8]> for ErrorReport { } } -#[cfg(test)] -mod tests { - use super::*; - use tokio::io::duplex; - - #[tokio::test] - async fn test_serial_notify_roundtrip() { - let (mut client, mut server) 
= duplex(1024); - - let original = SerialNotify::new(1, 42, 100); - - // 写入 - tokio::spawn(async move { - original.write(&mut client).await.unwrap(); - }); - - // 读取 - let decoded = SerialNotify::read(&mut server).await.unwrap(); - - assert_eq!(decoded.version(), 1); - assert_eq!(decoded.session_id(), 42); - assert_eq!(decoded.serial_number(), 100u32.to_be()); - } - - #[tokio::test] - async fn test_ipv4_prefix_roundtrip() { - use std::net::Ipv4Addr; - - let (mut client, mut server) = duplex(1024); - - let prefix = IPv4Prefix::new( - 1, - Flags::new(1), - 24, - 24, - Ipv4Addr::new(192,168,0,0), - 65000u32.into(), - ); - - tokio::spawn(async move { - prefix.write(&mut client).await.unwrap(); - }); - - let decoded = IPv4Prefix::read(&mut server).await.unwrap(); - - assert_eq!(decoded.prefix_len(), 24); - assert_eq!(decoded.max_len(), 24); - assert_eq!(decoded.prefix(), Ipv4Addr::new(192,168,0,0)); - assert_eq!(decoded.flag().is_announce(), true); - } -} diff --git a/src/rtr/server/config.rs b/src/rtr/server/config.rs index af46fe1..6920eb3 100644 --- a/src/rtr/server/config.rs +++ b/src/rtr/server/config.rs @@ -1,7 +1,12 @@ +use std::time::Duration; + #[derive(Debug, Clone)] pub struct RtrServiceConfig { pub max_connections: usize, pub notify_queue_size: usize, + pub tcp_keepalive: Option, + pub warn_insecure_tcp: bool, + pub require_tls_server_dns_name_san: bool, } impl Default for RtrServiceConfig { @@ -9,6 +14,9 @@ impl Default for RtrServiceConfig { Self { max_connections: 1024, notify_queue_size: 1024, + tcp_keepalive: Some(Duration::from_secs(60)), + warn_insecure_tcp: true, + require_tls_server_dns_name_san: false, } } -} \ No newline at end of file +} diff --git a/src/rtr/server/connection.rs b/src/rtr/server/connection.rs index 54870f4..d8e63d5 100644 --- a/src/rtr/server/connection.rs +++ b/src/rtr/server/connection.rs @@ -1,13 +1,15 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::{ Arc, atomic::{AtomicUsize, Ordering}, }; 
-use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use tokio::net::TcpStream; use tokio::sync::{broadcast, watch, OwnedSemaphorePermit}; -use tracing::error; +use tracing::{error, info, warn}; +use x509_parser::extensions::GeneralName; +use x509_parser::prelude::{FromDer, X509Certificate}; use tokio_rustls::TlsAcceptor; @@ -30,6 +32,10 @@ impl ConnectionGuard { _permit: permit, } } + + pub fn active_count(&self) -> usize { + self.active_connections.load(Ordering::Relaxed) + } } impl Drop for ConnectionGuard { @@ -52,6 +58,7 @@ pub async fn handle_tcp_connection( return Err(err); } + info!("RTR TCP session completed normally for {}", peer_addr); Ok(()) } @@ -63,10 +70,15 @@ pub async fn handle_tls_connection( notify_rx: broadcast::Receiver<()>, shutdown_rx: watch::Receiver, ) -> Result<()> { + info!("RTR TLS handshake started for {}", peer_addr); let tls_stream = acceptor .accept(stream) .await .with_context(|| format!("TLS handshake failed for {}", peer_addr))?; + info!("RTR TLS handshake completed for {}", peer_addr); + verify_peer_certificate_ip(&tls_stream, peer_addr.ip()) + .with_context(|| format!("TLS client certificate SAN IP validation failed for {}", peer_addr))?; + info!("RTR TLS client certificate validated for {}", peer_addr); let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx); @@ -75,5 +87,57 @@ pub async fn handle_tls_connection( return Err(err); } + info!("RTR TLS session completed normally for {}", peer_addr); Ok(()) -} \ No newline at end of file +} + +fn verify_peer_certificate_ip( + tls_stream: &tokio_rustls::server::TlsStream, + peer_ip: IpAddr, +) -> Result<()> { + let (_, server_connection) = tls_stream.get_ref(); + let peer_certs = server_connection + .peer_certificates() + .ok_or_else(|| anyhow!("missing peer certificate after TLS client authentication"))?; + let end_entity = peer_certs + .first() + .ok_or_else(|| anyhow!("peer did not present an end-entity certificate"))?; + + let (_, cert) = 
X509Certificate::from_der(end_entity.as_ref()) + .map_err(|err| anyhow!("failed to parse peer certificate: {:?}", err))?; + let san = cert + .subject_alternative_name() + .map_err(|err| anyhow!("failed to parse peer certificate SAN: {:?}", err))? + .ok_or_else(|| anyhow!("peer certificate is missing subjectAltName"))?; + + let matched = san.value.general_names.iter().any(|name| match name { + GeneralName::IPAddress(bytes) => { + let bytes = *bytes; + match (peer_ip, bytes.len()) { + (IpAddr::V4(ip), 4) => <[u8; 4]>::try_from(bytes) + .map(IpAddr::from) + .map(|cert_ip| cert_ip == IpAddr::V4(ip)) + .unwrap_or(false), + (IpAddr::V6(ip), 16) => <[u8; 16]>::try_from(bytes) + .map(IpAddr::from) + .map(|cert_ip| cert_ip == IpAddr::V6(ip)) + .unwrap_or(false), + _ => false, + } + } + _ => false, + }); + + if matched { + Ok(()) + } else { + warn!( + "RTR TLS client certificate SAN IP mismatch for peer_ip={}", + peer_ip + ); + Err(anyhow!( + "peer certificate subjectAltName iPAddress does not match {}", + peer_ip + )) + } +} diff --git a/src/rtr/server/listener.rs b/src/rtr/server/listener.rs index 77c30dc..28fe940 100644 --- a/src/rtr/server/listener.rs +++ b/src/rtr/server/listener.rs @@ -4,8 +4,10 @@ use std::sync::{ Arc, atomic::AtomicUsize, }; +use std::time::Duration; use anyhow::{Context, Result}; +use socket2::{SockRef, TcpKeepalive}; use tokio::net::TcpListener; use tokio::sync::{broadcast, watch, Semaphore}; use tracing::{info, warn}; @@ -15,7 +17,8 @@ use tokio_rustls::TlsAcceptor; use crate::rtr::cache::SharedRtrCache; use crate::rtr::server::connection::{ConnectionGuard, handle_tcp_connection, handle_tls_connection}; -use crate::rtr::server::tls::load_rustls_server_config; +use crate::rtr::server::config::RtrServiceConfig; +use crate::rtr::server::tls::load_rustls_server_config_with_options; pub struct RtrServer { bind_addr: SocketAddr, @@ -24,6 +27,7 @@ pub struct RtrServer { shutdown_tx: watch::Sender, connection_limiter: Arc, active_connections: Arc, + 
config: RtrServiceConfig, } impl RtrServer { @@ -34,6 +38,7 @@ impl RtrServer { shutdown_tx: watch::Sender, connection_limiter: Arc, active_connections: Arc, + config: RtrServiceConfig, ) -> Self { Self { bind_addr, @@ -42,6 +47,7 @@ impl RtrServer { shutdown_tx, connection_limiter, active_connections, + config, } } @@ -95,6 +101,9 @@ impl RtrServer { if let Err(err) = stream.set_nodelay(true) { warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err); } + if let Err(err) = apply_keepalive(&stream, self.config.tcp_keepalive) { + warn!("failed to configure TCP keepalive for {}: {}", peer_addr, err); + } let permit = match self.connection_limiter.clone().try_acquire_owned() { Ok(permit) => permit, @@ -102,7 +111,7 @@ impl RtrServer { warn!( "RTR TCP connection rejected for {}: max connections reached ({})", peer_addr, - self.connection_limiter.available_permits() + self.config.max_connections ); drop(stream); continue; @@ -114,16 +123,34 @@ impl RtrServer { let shutdown_rx = self.shutdown_tx.subscribe(); let active_connections = self.active_connections.clone(); - info!("RTR TCP client connected: {}", peer_addr); + info!( + "RTR TCP client connected: peer_addr={}, active_connections(before_spawn)={}", + peer_addr, + self.active_connections() + ); tokio::spawn(async move { - let _guard = ConnectionGuard::new(active_connections, permit); + let guard = ConnectionGuard::new(active_connections, permit); + info!( + "RTR TCP connection established: peer_addr={}, active_connections={}", + peer_addr, + guard.active_count() + ); if let Err(err) = handle_tcp_connection(cache, stream, peer_addr, notify_rx, shutdown_rx).await { - warn!("RTR TCP session {} ended with error: {:?}", peer_addr, err); + warn!( + "RTR TCP session closed with error: peer_addr={}, active_connections={}, err={:?}", + peer_addr, + guard.active_count(), + err + ); } else { - info!("RTR TCP session {} closed", peer_addr); + info!( + "RTR TCP session closed cleanly: peer_addr={}, 
active_connections={}", + peer_addr, + guard.active_count() + ); } }); } @@ -135,8 +162,14 @@ impl RtrServer { self, cert_path: impl AsRef, key_path: impl AsRef, + client_ca_path: impl AsRef, ) -> Result<()> { - let tls_config = Arc::new(load_rustls_server_config(cert_path, key_path)?); + let tls_config = Arc::new(load_rustls_server_config_with_options( + cert_path, + key_path, + client_ca_path, + self.config.require_tls_server_dns_name_san, + )?); self.run_tls(tls_config).await } @@ -179,11 +212,18 @@ impl RtrServer { if let Err(err) = stream.set_nodelay(true) { warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err); } + if let Err(err) = apply_keepalive(&stream, self.config.tcp_keepalive) { + warn!("failed to configure TCP keepalive for {}: {}", peer_addr, err); + } let permit = match self.connection_limiter.clone().try_acquire_owned() { Ok(permit) => permit, Err(_) => { - warn!("RTR TLS connection rejected for {}: max connections reached", peer_addr); + warn!( + "RTR TLS connection rejected for {}: max connections reached ({})", + peer_addr, + self.config.max_connections + ); drop(stream); continue; } @@ -195,10 +235,19 @@ impl RtrServer { let shutdown_rx = self.shutdown_tx.subscribe(); let active_connections = self.active_connections.clone(); - info!("RTR TLS client connected: {}", peer_addr); + info!( + "RTR TLS client connected: peer_addr={}, active_connections(before_spawn)={}", + peer_addr, + self.active_connections() + ); tokio::spawn(async move { - let _guard = ConnectionGuard::new(active_connections, permit); + let guard = ConnectionGuard::new(active_connections, permit); + info!( + "RTR TLS connection established: peer_addr={}, active_connections={}", + peer_addr, + guard.active_count() + ); if let Err(err) = handle_tls_connection( cache, stream, @@ -207,13 +256,38 @@ impl RtrServer { notify_rx, shutdown_rx, ).await { - warn!("RTR TLS session {} ended with error: {:?}", peer_addr, err); + warn!( + "RTR TLS session closed with error: 
peer_addr={}, active_connections={}, err={:?}", + peer_addr, + guard.active_count(), + err + ); } else { - info!("RTR TLS session {} closed", peer_addr); + info!( + "RTR TLS session closed cleanly: peer_addr={}, active_connections={}", + peer_addr, + guard.active_count() + ); } }); } } } } -} \ No newline at end of file +} + +fn apply_keepalive( + stream: &tokio::net::TcpStream, + keepalive: Option, +) -> Result<()> { + let Some(keepalive) = keepalive else { + return Ok(()); + }; + + let socket = SockRef::from(stream); + let keepalive = TcpKeepalive::new().with_time(keepalive); + socket + .set_tcp_keepalive(&keepalive) + .context("unable to apply TCP keepalive settings")?; + Ok(()) +} diff --git a/src/rtr/server/service.rs b/src/rtr/server/service.rs index 275a94a..cdb7f53 100644 --- a/src/rtr/server/service.rs +++ b/src/rtr/server/service.rs @@ -7,7 +7,7 @@ use std::sync::{ use tokio::sync::{broadcast, watch, Semaphore}; use tokio::task::JoinHandle; -use tracing::error; +use tracing::{error, warn}; use crate::rtr::cache::SharedRtrCache; use crate::rtr::server::config::RtrServiceConfig; @@ -70,6 +70,7 @@ impl RtrService { self.shutdown_tx.clone(), self.connection_limiter.clone(), self.active_connections.clone(), + self.config.clone(), ) } @@ -81,10 +82,17 @@ impl RtrService { self.shutdown_tx.clone(), self.connection_limiter.clone(), self.active_connections.clone(), + self.config.clone(), ) } pub fn spawn_tcp(&self, bind_addr: SocketAddr) -> JoinHandle<()> { + if self.config.warn_insecure_tcp { + warn!( + "starting plain TCP RTR service on {}. 
Per draft-ietf-sidrops-8210bis-25 Section 9, unsecured TCP must only be used on a trusted and controlled network", + bind_addr + ); + } let server = self.tcp_server(bind_addr); tokio::spawn(async move { if let Err(err) = server.run_tcp().await { @@ -98,13 +106,15 @@ impl RtrService { bind_addr: SocketAddr, cert_path: impl AsRef, key_path: impl AsRef, + client_ca_path: impl AsRef, ) -> JoinHandle<()> { let cert_path = cert_path.as_ref().to_path_buf(); let key_path = key_path.as_ref().to_path_buf(); + let client_ca_path = client_ca_path.as_ref().to_path_buf(); let server = self.tls_server(bind_addr); tokio::spawn(async move { - if let Err(err) = server.run_tls_from_pem(cert_path, key_path).await { + if let Err(err) = server.run_tls_from_pem(cert_path, key_path, client_ca_path).await { error!("RTR TLS server {} exited with error: {:?}", bind_addr, err); } }) @@ -116,9 +126,10 @@ impl RtrService { tls_bind_addr: SocketAddr, cert_path: impl AsRef, key_path: impl AsRef, + client_ca_path: impl AsRef, ) -> RunningRtrService { let tcp_handle = self.spawn_tcp(tcp_bind_addr); - let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path); + let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path); RunningRtrService { shutdown_tx: self.shutdown_tx.clone(), @@ -151,4 +162,4 @@ impl RunningRtrService { let _ = handle.await; } } -} \ No newline at end of file +} diff --git a/src/rtr/server/tls.rs b/src/rtr/server/tls.rs index 6e178bc..ca7f02c 100644 --- a/src/rtr/server/tls.rs +++ b/src/rtr/server/tls.rs @@ -1,32 +1,101 @@ use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; +use std::sync::Arc; use anyhow::{anyhow, Context, Result}; -use rustls::ServerConfig; +use rustls::server::WebPkiClientVerifier; +use rustls::{RootCertStore, ServerConfig}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; +use tracing::warn; +use x509_parser::extensions::GeneralName; +use x509_parser::prelude::{FromDer, 
X509Certificate}; pub fn load_rustls_server_config( cert_path: impl AsRef, key_path: impl AsRef, + client_ca_path: impl AsRef, +) -> Result { + load_rustls_server_config_with_options(cert_path, key_path, client_ca_path, false) +} + +pub fn load_rustls_server_config_with_options( + cert_path: impl AsRef, + key_path: impl AsRef, + client_ca_path: impl AsRef, + require_dns_name_san: bool, ) -> Result { let cert_path: PathBuf = cert_path.as_ref().to_path_buf(); let key_path: PathBuf = key_path.as_ref().to_path_buf(); + let client_ca_path: PathBuf = client_ca_path.as_ref().to_path_buf(); let certs = load_certs(&cert_path) .with_context(|| format!("failed to load certs from {}", cert_path.display()))?; + validate_server_certificate_dns_name_san(&certs, &cert_path, require_dns_name_san)?; let key = load_private_key(&key_path) .with_context(|| format!("failed to load private key from {}", key_path.display()))?; + let client_ca_certs = load_certs(&client_ca_path) + .with_context(|| format!("failed to load client CA certs from {}", client_ca_path.display()))?; + let mut client_roots = RootCertStore::empty(); + let (added, _) = client_roots.add_parsable_certificates(client_ca_certs); + if added == 0 { + return Err(anyhow!( + "no valid client CA certificates found in {}", + client_ca_path.display() + )); + } + let client_verifier = WebPkiClientVerifier::builder(Arc::new(client_roots)) + .build() + .map_err(|e| anyhow!("invalid client certificate verifier configuration: {}", e))?; + let config = ServerConfig::builder() - .with_no_client_auth() + .with_client_cert_verifier(client_verifier) .with_single_cert(certs, key) .map_err(|e| anyhow!("invalid certificate/key pair: {}", e))?; Ok(config) } +fn validate_server_certificate_dns_name_san( + certs: &[CertificateDer<'static>], + cert_path: &Path, + require_dns_name_san: bool, +) -> Result<()> { + let leaf = certs + .first() + .ok_or_else(|| anyhow!("missing end-entity certificate in {}", cert_path.display()))?; + let (_, cert) = 
X509Certificate::from_der(leaf.as_ref()) + .map_err(|err| anyhow!("failed to parse server certificate: {:?}", err))?; + let has_dns_name_san = cert + .subject_alternative_name() + .map_err(|err| anyhow!("failed to parse server certificate SAN: {:?}", err))? + .map(|san| { + san.value + .general_names + .iter() + .any(|name| matches!(name, GeneralName::DNSName(_))) + }) + .unwrap_or(false); + + if has_dns_name_san { + return Ok(()); + } + + let message = format!( + "server certificate {} does not contain a subjectAltName dNSName entry; draft-ietf-sidrops-8210bis-25 Section 9.2 requires routers to authenticate the cache using DNS-ID rather than CN-ID", + cert_path.display() + ); + + if require_dns_name_san { + Err(anyhow!(message)) + } else { + warn!("{}", message); + Ok(()) + } +} + fn load_certs(path: &Path) -> Result>> { let file = File::open(path)?; let mut reader = BufReader::new(file); @@ -49,4 +118,4 @@ fn load_private_key(path: &Path) -> Result> { .ok_or_else(|| anyhow!("no private key found in {}", path.display()))?; Ok(key) -} \ No newline at end of file +} diff --git a/src/rtr/session.rs b/src/rtr/session.rs index 500b42b..ff56c31 100644 --- a/src/rtr/session.rs +++ b/src/rtr/session.rs @@ -1,17 +1,26 @@ -use anyhow::{bail, Result}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::{anyhow, bail, Result}; use tokio::io; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::sync::{broadcast, watch}; -use tracing::warn; +use tokio::time::timeout; +use tracing::{debug, error, info, warn}; use crate::data_model::resources::ip_resources::IPAddress; -use crate::rtr::cache::{Delta, SerialResult, SharedRtrCache}; -use crate::rtr::error_type::ErrorCode; -use crate::rtr::payload::{Payload, RouteOrigin}; -use crate::rtr::pdu::{ - CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix, - ResetQuery, SerialNotify, SerialQuery, +use crate::rtr::cache::{ + Delta, 
SerialResult, SharedRtrCache, validate_payload_updates_for_rtr, + validate_payloads_for_rtr, }; +use crate::rtr::error_type::ErrorCode; +use crate::rtr::pdu::{ + Aspa as AspaPdu, + CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix, + ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery, + HEADER_LEN, +}; +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; const SUPPORTED_MAX_VERSION: u8 = 2; const SUPPORTED_MIN_VERSION: u8 = 0; @@ -19,6 +28,9 @@ const SUPPORTED_MIN_VERSION: u8 = 0; const ANNOUNCE_FLAG: u8 = 1; const WITHDRAW_FLAG: u8 = 0; +/// Per-session notify rate limit: no more than once per minute. +const NOTIFY_MIN_INTERVAL: Duration = Duration::from_secs(60); + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SessionState { Connected, @@ -33,6 +45,8 @@ pub struct RtrSession { state: SessionState, notify_rx: broadcast::Receiver<()>, shutdown_rx: watch::Receiver, + last_notify_at: Option, + transport_timeout_override: Option, } impl RtrSession @@ -52,207 +66,720 @@ where state: SessionState::Connected, notify_rx, shutdown_rx, + last_notify_at: None, + transport_timeout_override: None, } } + pub fn with_transport_timeout(mut self, timeout: Duration) -> Self { + self.transport_timeout_override = Some(timeout); + self + } + pub async fn run(mut self) -> Result<()> { + info!( + "RTR session started: {}", + self.session_summary() + ); loop { + let transport_timeout = self.transport_timeout(); tokio::select! 
{ - changed = self.shutdown_rx.changed() => { - match changed { - Ok(()) => { - if *self.shutdown_rx.borrow() { + changed = self.shutdown_rx.changed() => { + match changed { + Ok(()) => { + if *self.shutdown_rx.borrow() { + info!("RTR session closing due to service shutdown: {}", self.session_summary()); + self.state = SessionState::Closed; + return Ok(()); + } + } + Err(_) => { + info!("RTR session closing because shutdown channel closed: {}", self.session_summary()); self.state = SessionState::Closed; return Ok(()); } } - Err(_) => { - // shutdown sender dropped,按关闭处理 + } + + header_res = timeout(transport_timeout, Header::read_raw(&mut self.stream)) => { + let raw_header = match header_res { + Ok(Ok(raw)) => raw, + Ok(Err(_)) => { + info!("RTR session closed by peer before header read completed: {}", self.session_summary()); + self.state = SessionState::Closed; + return Ok(()); + } + Err(_) => { + warn!("RTR session transport timeout while waiting for header: {}", self.session_summary()); + self.handle_transport_timeout(&[]).await?; + return Ok(()); + } + }; + let header = match Header::from_raw(raw_header) { + Ok(h) => h, + Err(err) => { + warn!( + "RTR session received invalid header: err={}, {}", + err, + self.session_summary() + ); + self.handle_header_read_error(raw_header, err).await?; + return Ok(()); + } + }; + debug!( + "RTR session received header: pdu={}, version={}, length={}, state={}", + header.pdu(), + header.version(), + header.length(), + self.state_name() + ); + let header_prefix = header.as_ref(); + + if self.state == SessionState::Connected { + self.handle_first_pdu(header).await?; + if self.state == SessionState::Closed { + info!("RTR session closed during first-PDU handling: {}", self.session_summary()); + return Ok(()); + } + continue; + } + + if self.state != SessionState::Established { self.state = SessionState::Closed; - return Ok(()); + error!("RTR session entered invalid state: {}", self.session_summary()); + bail!("session is not in 
a valid state to receive PDUs"); + } + + self.ensure_established_version(header).await?; + + match header.pdu() { + ResetQuery::PDU => { + let query = match timeout( + self.transport_timeout(), + ResetQuery::read_payload(header, &mut self.stream), + ) + .await + { + Ok(Ok(query)) => query, + Ok(Err(err)) => { + self.handle_pdu_read_error(header, err).await?; + return Ok(()); + } + Err(_) => { + self.handle_transport_timeout(header_prefix).await?; + return Ok(()); + } + }; + self.handle_reset_query(query.as_ref()).await?; + } + + SerialQuery::PDU => { + let query = match timeout( + self.transport_timeout(), + SerialQuery::read_payload(header, &mut self.stream), + ) + .await + { + Ok(Ok(query)) => query, + Ok(Err(err)) => { + self.handle_pdu_read_error(header, err).await?; + return Ok(()); + } + Err(_) => { + self.handle_transport_timeout(header_prefix).await?; + return Ok(()); + } + }; + let session_id = query.session_id(); + let serial = query.serial_number(); + self.handle_serial(header.version(), session_id, serial, query.as_ref()).await?; + } + + ErrorReport::PDU => { + let _ = timeout( + self.transport_timeout(), + ErrorReport::skip_payload(header, &mut self.stream), + ) + .await; + info!("RTR session received ErrorReport from peer, closing session: {}", self.session_summary()); + self.state = SessionState::Closed; + return Ok(()); + } + + _ => { + let offending = self.read_full_pdu_bytes(header).await?; + let version = self.version.ok_or_else(|| anyhow!("missing negotiated version"))?; + + self.send_error( + version, + ErrorCode::UnsupportedPduType, + &offending, + &[], + ) + .await?; + + warn!( + "RTR session received unsupported PDU type {}, closing session: {}", + header.pdu(), + self.session_summary() + ); + self.state = SessionState::Closed; + return Ok(()); + } + } + } + + notify_res = self.notify_rx.recv(), + if self.state == SessionState::Established && self.version.is_some() => { + match notify_res { + Ok(()) => { + debug!("RTR session handling cache 
update notify: {}", self.session_summary()); + self.handle_notify().await?; + } + Err(broadcast::error::RecvError::Lagged(_)) => { + warn!("RTR session lagged on notify channel, forcing notify handling: {}", self.session_summary()); + self.handle_notify().await?; + } + Err(broadcast::error::RecvError::Closed) => { + debug!("RTR session notify channel closed, keeping session alive: {}", self.session_summary()); + // keep session alive + } } } } - - header_res = Header::read(&mut self.stream) => { - let header = match header_res { - Ok(h) => h, - Err(_) => { - self.state = SessionState::Closed; - return Ok(()); - } - }; - - if self.version.is_none() { - self.negotiate_version(header.version()).await?; - } else if header.version() != self.version.unwrap() { - self.send_unsupported_version(self.version.unwrap()).await?; - self.state = SessionState::Closed; - bail!("version changed within session"); - } - - match header.pdu() { - ResetQuery::PDU => { - let _ = ResetQuery::read_payload(header, &mut self.stream).await?; - self.handle_reset_query().await?; - } - SerialQuery::PDU => { - let query = SerialQuery::read_payload(header, &mut self.stream).await?; - let session_id = query.session_id(); - let serial = u32::from_be(query.serial_number()); - self.handle_serial(session_id, serial).await?; - } - ErrorReport::PDU => { - let _ = ErrorReport::skip_payload(header, &mut self.stream).await; - self.state = SessionState::Closed; - return Ok(()); - } - _ => { - self.send_error( - header.version(), - ErrorCode::UnsupportedPduType, - Some(&header), - &[], - ) - .await?; - self.state = SessionState::Closed; - return Ok(()); - } - } - } - - notify_res = self.notify_rx.recv(), - if self.state == SessionState::Established && self.version.is_some() => { - match notify_res { - Ok(()) => { - self.handle_notify().await?; - } - Err(broadcast::error::RecvError::Lagged(_)) => { - self.handle_notify().await?; - } - Err(broadcast::error::RecvError::Closed) => { - // notify 
通道关闭,不影响已有会话,继续跑,真正关闭由 shutdown_rx 控制 - } - } - } - } } } async fn negotiate_version(&mut self, router_version: u8) -> io::Result { - if router_version < SUPPORTED_MIN_VERSION { - self.send_unsupported_version(SUPPORTED_MIN_VERSION).await?; + if let Some(current) = self.version { + if current == router_version { + return Ok(current); + } + return Err(io::Error::new( io::ErrorKind::InvalidData, - "unsupported lower protocol version", + format!( + "protocol version changed after negotiation: established={}, received={}", + current, router_version + ), )); } if router_version > SUPPORTED_MAX_VERSION { - self.send_unsupported_version(SUPPORTED_MAX_VERSION).await?; return Err(io::Error::new( io::ErrorKind::InvalidData, - "router version higher than cache", + format!( + "router version {} higher than cache max {}", + router_version, SUPPORTED_MAX_VERSION + ), )); } self.version = Some(router_version); + info!( + "RTR session negotiated protocol version: negotiated_version={}, state={}", + router_version, + self.state_name() + ); Ok(router_version) } - async fn send_unsupported_version(&mut self, cache_version: u8) -> io::Result<()> { - ErrorReport::new( + async fn ensure_established_version(&mut self, header: Header) -> Result<()> { + let established = self + .version + .ok_or_else(|| anyhow!("session version not negotiated"))?; + + if header.version() == established { + return Ok(()); + } + + warn!( + "RTR session received unexpected protocol version in established session: established_version={}, received_version={}, pdu={}", + established, + header.version(), + header.pdu() + ); + + if header.pdu() != ErrorReport::PDU { + let offending = self.read_full_pdu_bytes(header).await?; + let _ = self + .send_unexpected_version(established, &offending, header.version()) + .await; + } + + self.state = SessionState::Closed; + bail!( + "protocol version changed within session: established={}, received={}", + established, + header.version() + ); + } + + async fn 
send_unsupported_version( + &mut self, + cache_version: u8, + offending_pdu: &[u8], + received_version: u8, + ) -> io::Result<()> { + let msg = format!( + "unsupported protocol version {}, cache supports versions {}..={}", + received_version, SUPPORTED_MIN_VERSION, SUPPORTED_MAX_VERSION + ); + + self.send_error( cache_version, - ErrorCode::UnsupportedProtocolVersion.as_u16(), - &[], - ErrorCode::UnsupportedProtocolVersion.description(), + ErrorCode::UnsupportedProtocolVersion, + offending_pdu, + msg.as_bytes(), ) - .write(&mut self.stream) .await } - async fn handle_reset_query(&mut self) -> Result<()> { - let (payloads, session_id, serial) = { + async fn send_unexpected_version( + &mut self, + established_version: u8, + offending_pdu: &[u8], + received_version: u8, + ) -> io::Result<()> { + let msg = format!( + "unexpected protocol version {}, established version is {}", + received_version, + established_version + ); + + self.send_error( + established_version, + ErrorCode::UnexpectedProtocolVersion, + offending_pdu, + msg.as_bytes(), + ) + .await + } + + async fn send_corrupt_session_id( + &mut self, + offending_pdu: &[u8], + received_session_id: u16, + expected_session_id: u16, + ) -> io::Result<()> { + let version = self.version()?; + let msg = format!( + "session id mismatch in established session: received={}, expected={}", + received_session_id, expected_session_id + ); + + self.send_error( + version, + ErrorCode::CorruptData, + offending_pdu, + msg.as_bytes(), + ) + .await + } + + async fn send_corrupt_data( + &mut self, + version: u8, + offending_pdu: &[u8], + detail: &[u8], + ) -> io::Result<()> { + self.send_error( + version, + ErrorCode::CorruptData, + offending_pdu, + detail, + ) + .await + } + + async fn send_no_data_available( + &mut self, + offending_pdu: &[u8], + detail: &'static str, + ) -> io::Result<()> { + let version = self.version()?; + self.send_error( + version, + ErrorCode::NoDataAvailable, + offending_pdu, + detail.as_bytes(), + ) + .await 
+ } + + async fn handle_first_pdu(&mut self, header: Header) -> Result<()> { + info!( + "RTR session processing first PDU: pdu={}, version={}, length={}", + header.pdu(), + header.version(), + header.length() + ); + match header.pdu() { + ResetQuery::PDU => { + let version = header.version(); + let query = match timeout( + self.transport_timeout(), + ResetQuery::read_payload(header, &mut self.stream), + ) + .await + { + Ok(Ok(query)) => query, + Ok(Err(err)) => { + self.handle_first_pdu_read_error(header, err).await?; + return Ok(()); + } + Err(_) => { + self.handle_transport_timeout(header.as_ref()).await?; + return Ok(()); + } + }; + if version > SUPPORTED_MAX_VERSION { + self.send_unsupported_version(SUPPORTED_MAX_VERSION, query.as_ref(), version) + .await?; + self.state = SessionState::Closed; + bail!( + "router version {} higher than cache max {}", + version, SUPPORTED_MAX_VERSION + ); + } + self.negotiate_version(version).await?; + self.handle_reset_query(query.as_ref()).await?; + self.state = SessionState::Established; + info!( + "RTR session established after Reset Query: negotiated_version={}, {}", + version, + self.session_summary() + ); + } + + SerialQuery::PDU => { + let version = header.version(); + let query = match timeout( + self.transport_timeout(), + SerialQuery::read_payload(header, &mut self.stream), + ) + .await + { + Ok(Ok(query)) => query, + Ok(Err(err)) => { + self.handle_first_pdu_read_error(header, err).await?; + return Ok(()); + } + Err(_) => { + self.handle_transport_timeout(header.as_ref()).await?; + return Ok(()); + } + }; + if version > SUPPORTED_MAX_VERSION { + self.send_unsupported_version(SUPPORTED_MAX_VERSION, query.as_ref(), version) + .await?; + self.state = SessionState::Closed; + bail!( + "router version {} higher than cache max {}", + version, SUPPORTED_MAX_VERSION + ); + } + self.negotiate_version(version).await?; + + let session_id = query.session_id(); + let serial = query.serial_number(); + self.handle_serial(version, 
session_id, serial, query.as_ref()).await?; + self.state = SessionState::Established; + info!( + "RTR session established after Serial Query: negotiated_version={}, client_session_id={}, client_serial={}, {}", + version, + session_id, + serial, + self.session_summary() + ); + } + + ErrorReport::PDU => { + let _ = timeout( + self.transport_timeout(), + ErrorReport::skip_payload(header, &mut self.stream), + ) + .await; + self.state = SessionState::Closed; + bail!("received ErrorReport before session establishment"); + } + + _ => { + let offending = self.read_full_pdu_bytes(header).await?; + let err_version = if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION) + .contains(&header.version()) + { + header.version() + } else { + SUPPORTED_MAX_VERSION + }; + + let _ = self + .send_error( + err_version, + ErrorCode::InvalidRequest, + &offending, + b"first PDU must be Reset Query or Serial Query", + ) + .await; + + self.state = SessionState::Closed; + bail!("first PDU must be Reset Query or Serial Query"); + } + } + + Ok(()) + } + + async fn handle_reset_query(&mut self, offending_pdu: &[u8]) -> Result<()> { + info!( + "RTR session received Reset Query: negotiated_version={:?}, offending_pdu_len={}", + self.version, + offending_pdu.len() + ); + let (data_available, payloads, session_id, serial) = { + let version = self.version()?; let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + .map_err(|_| anyhow!("cache read lock poisoned"))?; + let data_available = cache.is_data_available(); let snapshot = cache.snapshot(); let payloads = snapshot.payloads_for_rtr(); - let session_id = cache.session_id(); + let session_id = cache.session_id_for_version(version); let serial = cache.serial(); - (payloads, session_id, serial) + (data_available, payloads, session_id, serial) }; + if !data_available { + self.send_no_data_available(offending_pdu, "cache data is not currently available") + .await?; + info!( + "RTR session replied No Data Available 
to Reset Query: {}", + self.session_summary() + ); + return Ok(()); + } + self.write_cache_response(session_id).await?; self.send_payloads(&payloads, true).await?; self.write_end_of_data(session_id, serial).await?; + info!( + "RTR session completed Reset Query: response_session_id={}, response_serial={}, payload_count={}, {}", + session_id, + serial, + payloads.len(), + self.session_summary() + ); - self.state = SessionState::Established; Ok(()) } - async fn handle_serial(&mut self, client_session: u16, client_serial: u32) -> Result<()> { + async fn handle_serial( + &mut self, + version: u8, + client_session: u16, + client_serial: u32, + offending_pdu: &[u8], + ) -> Result<()> { + info!( + "RTR session received Serial Query: negotiated_version={}, client_session_id={}, client_serial={}, offending_pdu_len={}", + version, + client_session, + client_serial, + offending_pdu.len() + ); + let (data_available, current_session) = { + let cache = self + .cache + .read() + .map_err(|_| anyhow!("cache read lock poisoned"))?; + (cache.is_data_available(), cache.session_id_for_version(version)) + }; + + if !data_available { + self.send_no_data_available(offending_pdu, "cache data is not currently available") + .await?; + info!( + "RTR session replied No Data Available to Serial Query: client_session_id={}, client_serial={}, {}", + client_session, + client_serial, + self.session_summary() + ); + return Ok(()); + } + + // Strict 8210bis behavior: + // within an established transport session, a Session ID mismatch is fatal. 
+ if client_session != current_session { + self.send_corrupt_session_id(offending_pdu, client_session, current_session) + .await?; + self.state = SessionState::Closed; + bail!( + "session id mismatch in established session: received={}, expected={}", + client_session, + current_session + ); + } + let serial_result = { let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; - cache.get_deltas_since(client_session, client_serial) + .map_err(|_| anyhow!("cache read lock poisoned"))?; + cache.get_deltas_since(client_serial) }; match serial_result { SerialResult::ResetRequired => { self.write_cache_reset().await?; - self.state = SessionState::Established; + info!( + "RTR session replied Cache Reset to Serial Query: client_session_id={}, client_serial={}, {}", + client_session, + client_serial, + self.session_summary() + ); return Ok(()); } + SerialResult::UpToDate => { let (current_session, current_serial) = { let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; - (cache.session_id(), cache.serial()) + .map_err(|_| anyhow!("cache read lock poisoned"))?; + (cache.session_id_for_version(version), cache.serial()) }; self.write_end_of_data(current_session, current_serial) .await?; - self.state = SessionState::Established; + info!( + "RTR session replied EndOfData (up-to-date) to Serial Query: client_session_id={}, client_serial={}, response_session_id={}, response_serial={}, {}", + client_session, + client_serial, + current_session, + current_serial, + self.session_summary() + ); return Ok(()); } - SerialResult::Deltas(deltas) => { + + SerialResult::Delta(delta) => { let (current_session, current_serial) = { let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; - (cache.session_id(), cache.serial()) + .map_err(|_| anyhow!("cache read lock poisoned"))?; + (cache.session_id_for_version(version), cache.serial()) }; 
self.write_cache_response(current_session).await?; - for delta in deltas { - self.send_delta(&delta).await?; - } + self.send_delta(&delta).await?; self.write_end_of_data(current_session, current_serial) .await?; + info!( + "RTR session replied delta to Serial Query: client_session_id={}, client_serial={}, response_session_id={}, response_serial={}, {}", + client_session, + client_serial, + current_session, + current_serial, + self.session_summary() + ); } } - self.state = SessionState::Established; Ok(()) } async fn handle_notify(&mut self) -> Result<()> { + if self.state != SessionState::Established { + return Ok(()); + } + + let version = match self.version { + Some(v) => v, + None => return Ok(()), + }; + + let now = Instant::now(); + if let Some(last) = self.last_notify_at { + if now.duration_since(last) < NOTIFY_MIN_INTERVAL { + debug!("RTR session notify skipped due to rate limit: {}", self.session_summary()); + return Ok(()); + } + } + let (session_id, serial) = { let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; - (cache.session_id(), cache.serial()) + .map_err(|_| anyhow!("cache read lock poisoned"))?; + (cache.session_id_for_version(version), cache.serial()) }; - self.send_serial_notify(session_id, serial).await + debug!( + "RTR session sending SerialNotify: notify_session_id={}, notify_serial={}, {}", + session_id, + serial, + self.session_summary() + ); + if let Err(err) = SerialNotify::new(version, session_id, serial) + .write(&mut self.stream) + .await + { + error!( + "RTR session failed to send SerialNotify: err={}, notify_session_id={}, notify_serial={}, {}", + err, + session_id, + serial, + self.session_summary() + ); + return Err(err.into()); + } + + self.last_notify_at = Some(now); + info!( + "RTR session sent SerialNotify: notify_session_id={}, notify_serial={}, {}", + session_id, + serial, + self.session_summary() + ); + Ok(()) + } + + fn session_summary(&self) -> String { + let serial = self + .cache 
+ .read() + .ok() + .map(|cache| cache.serial().to_string()) + .unwrap_or_else(|| "".to_string()); + let session_id = self + .version + .and_then(|version| { + self.cache + .read() + .ok() + .map(|cache| cache.session_id_for_version(version)) + }) + .map(|id| id.to_string()) + .unwrap_or_else(|| "".to_string()); + format!( + "state={}, negotiated_version={:?}, cache_session_id={}, cache_serial={}", + self.state_name(), + self.version, + session_id, + serial + ) + } + + fn state_name(&self) -> &'static str { + match self.state { + SessionState::Connected => "Connected", + SessionState::Established => "Established", + SessionState::Closed => "Closed", + } } fn version(&self) -> io::Result { @@ -260,16 +787,28 @@ where .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")) } - async fn send_serial_notify(&mut self, session_id: u16, serial: u32) -> Result<()> { - let version = self.version()?; - SerialNotify::new(version, session_id, serial) - .write(&mut self.stream) - .await?; - Ok(()) + fn transport_timeout(&self) -> Duration { + if let Some(timeout) = self.transport_timeout_override { + return timeout; + } + + let retry = self + .cache + .read() + .ok() + .map(|cache| cache.timing().retry()) + .unwrap_or_else(|| Duration::from_secs(600)); + + retry.checked_mul(3).unwrap_or(retry) } async fn write_cache_response(&mut self, session_id: u16) -> Result<()> { let version = self.version()?; + debug!( + "RTR session writing Cache Response: version={}, session_id={}", + version, + session_id + ); CacheResponse::new(version, session_id) .write(&mut self.stream) .await?; @@ -278,6 +817,10 @@ where async fn write_cache_reset(&mut self) -> Result<()> { let version = self.version()?; + info!( + "RTR session writing Cache Reset: version={}", + version + ); CacheReset::new(version).write(&mut self.stream).await?; Ok(()) } @@ -288,11 +831,20 @@ where let cache = self .cache .read() - .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + 
.map_err(|_| anyhow!("cache read lock poisoned"))?; cache.timing() }; - let end = EndOfData::new(version, session_id, serial, timing); + let end = EndOfData::new(version, session_id, serial, timing)?; + debug!( + "RTR session writing EndOfData: version={}, session_id={}, serial={}, refresh={}s, retry={}s, expire={}s", + version, + session_id, + serial, + timing.refresh().as_secs(), + timing.retry().as_secs(), + timing.expire().as_secs() + ); match end { EndOfData::V0(pdu) => pdu.write(&mut self.stream).await?, EndOfData::V1(pdu) => pdu.write(&mut self.stream).await?, @@ -302,6 +854,25 @@ where } async fn send_payloads(&mut self, payloads: &[Payload], announce: bool) -> Result<()> { + // draft-ietf-sidrops-8210bis-25 Section 11.4 / 12 define Ordering Error + // as a receiver-side response to PDUs received in the wrong order. + // When we detect an ordering issue before sending, this is a local cache/ + // implementation fault rather than a received protocol error, so we fail + // the session internally instead of emitting ErrorReport(code=11). 
+ // References: + // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 + // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 + validate_payloads_for_rtr(payloads, announce) + .map_err(|err| anyhow!(err.to_string()))?; + let (route_origins, router_keys, aspas) = count_payloads(payloads); + debug!( + "RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}", + announce, + payloads.len(), + route_origins, + router_keys, + aspas + ); for payload in payloads { self.send_payload(payload, announce).await?; } @@ -309,25 +880,48 @@ where } async fn send_delta(&mut self, delta: &Delta) -> Result<()> { - for payload in delta.announced() { - self.send_payload(payload, true).await?; - } - for payload in delta.withdrawn() { - self.send_payload(payload, false).await?; + let updates = delta.payloads_for_rtr(); + // draft-ietf-sidrops-8210bis-25 Section 11.4 / 12 define Ordering Error + // for the party receiving out-of-order PDUs. A validator failure here + // means we are about to send an invalid sequence, so abort locally + // instead of reporting Ordering Error back to the router. 
+ // References: + // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 + // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 + validate_payload_updates_for_rtr(&updates) + .map_err(|err| anyhow!(err.to_string()))?; + let (announced, withdrawn, route_origins, router_keys, aspas) = + count_payload_updates(&updates); + debug!( + "RTR session sending delta payloads: total={}, announced={}, withdrawn={}, route_origins={}, router_keys={}, aspas={}", + updates.len(), + announced, + withdrawn, + route_origins, + router_keys, + aspas + ); + for (announce, payload) in updates { + self.send_payload(&payload, announce).await?; } Ok(()) } async fn send_payload(&mut self, payload: &Payload, announce: bool) -> Result<()> { + let version = self.version()?; match payload { Payload::RouteOrigin(origin) => { self.send_route_origin(origin, announce).await?; } - Payload::RouterKey(_) => { - warn!("router key payload not supported yet"); + Payload::RouterKey(key) => { + if version >= 1 { + self.send_router_key(key, announce).await?; + } } - Payload::Aspa(_) => { - warn!("aspa payload not supported yet"); + Payload::Aspa(aspa) => { + if version >= 2 { + self.send_aspa(aspa, announce).await?; + } } } Ok(()) @@ -362,17 +956,252 @@ where Ok(()) } + async fn send_router_key(&mut self, key: &RouterKey, announce: bool) -> Result<()> { + let version = self.version()?; + + let flags = Flags::new(if announce { + ANNOUNCE_FLAG + } else { + WITHDRAW_FLAG + }); + + let pdu = RouterKeyPdu::new( + version, + flags, + key.ski(), + key.asn(), + Arc::<[u8]>::from(key.spki().to_vec()), + ); + + pdu.write(&mut self.stream).await?; + Ok(()) + } + + async fn send_aspa(&mut self, aspa: &Aspa, announce: bool) -> Result<()> { + let version = self.version()?; + + if announce { + aspa.validate_announcement()?; + } + + let flags = Flags::new(if announce { + ANNOUNCE_FLAG + } else { + WITHDRAW_FLAG + }); + + let providers = if announce { + aspa + 
.provider_asns() + .iter() + .map(|asn| asn.into_u32()) + .collect::>() + } else { + Vec::new() + }; + + let pdu = AspaPdu::new( + version, + flags, + aspa.customer_asn().into_u32(), + providers, + ); + + pdu.write(&mut self.stream).await?; + Ok(()) + } + + async fn send_error( &mut self, version: u8, code: ErrorCode, - offending_header: Option<&Header>, + offending_pdu: &[u8], text: &[u8], ) -> io::Result<()> { - let offending = offending_header.map(|h| h.as_ref()).unwrap_or(&[]); - - ErrorReport::new(version, code.as_u16(), offending, text) + let text_preview = String::from_utf8_lossy(text); + warn!( + "RTR session sending ErrorReport: version={}, error_code={}({}), offending_pdu_len={}, text={}", + version, + code.as_u16(), + code.description(), + offending_pdu.len(), + text_preview + ); + ErrorReport::new(version, code.as_u16(), offending_pdu, text) .write(&mut self.stream) .await } -} \ No newline at end of file + + async fn handle_pdu_read_error( + &mut self, + header: Header, + err: io::Error, + ) -> Result<()> { + warn!( + "RTR session failed to read established-session PDU payload: pdu={}, version={}, err={}", + header.pdu(), + header.version(), + err + ); + if err.kind() == io::ErrorKind::InvalidData { + let offending = self.read_full_pdu_bytes(header).await?; + let version = self.version()?; + let detail = err.to_string(); + let _ = self + .send_corrupt_data(version, &offending, detail.as_bytes()) + .await; + } + + self.state = SessionState::Closed; + Ok(()) + } + + async fn handle_first_pdu_read_error( + &mut self, + header: Header, + err: io::Error, + ) -> Result<()> { + warn!( + "RTR session failed to read first PDU payload: pdu={}, version={}, err={}", + header.pdu(), + header.version(), + err + ); + if err.kind() == io::ErrorKind::InvalidData { + let offending = self.read_full_pdu_bytes(header).await?; + let err_version = if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION) + .contains(&header.version()) + { + header.version() + } else { + 
SUPPORTED_MAX_VERSION + }; + let detail = err.to_string(); + + let _ = self + .send_corrupt_data(err_version, &offending, detail.as_bytes()) + .await; + } + + self.state = SessionState::Closed; + Ok(()) + } + + async fn handle_header_read_error( + &mut self, + raw_header: [u8; HEADER_LEN], + err: io::Error, + ) -> Result<()> { + warn!( + "RTR session handling invalid header bytes: raw_header={:02X?}, err={}", + raw_header, + err + ); + if err.kind() == io::ErrorKind::InvalidData { + let version = match self.version { + Some(version) => version, + None if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION).contains(&raw_header[0]) => { + raw_header[0] + } + None => SUPPORTED_MAX_VERSION, + }; + let detail = err.to_string(); + + let _ = self + .send_corrupt_data(version, &raw_header, detail.as_bytes()) + .await; + } + + self.state = SessionState::Closed; + Ok(()) + } + + async fn handle_transport_timeout(&mut self, offending_pdu: &[u8]) -> Result<()> { + let version = self.version.unwrap_or(SUPPORTED_MAX_VERSION); + let timeout = self.transport_timeout(); + let detail = format!( + "transport stalled for longer than {:?}", + timeout + ); + warn!( + "RTR session transport timeout: version={}, offending_pdu_len={}, timeout={:?}", + version, + offending_pdu.len(), + timeout + ); + + let _ = self + .send_error( + version, + ErrorCode::TransportFailed, + offending_pdu, + detail.as_bytes(), + ) + .await; + + self.state = SessionState::Closed; + Ok(()) + } + + async fn read_full_pdu_bytes(&mut self, header: Header) -> io::Result> { + let total_len = header.pdu_len()?; + let Some(payload_len) = total_len.checked_sub(HEADER_LEN) else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU size smaller than header size", + )); + }; + + let mut bytes = Vec::with_capacity(total_len); + bytes.extend_from_slice(header.as_ref()); + bytes.resize(total_len, 0); + timeout( + self.transport_timeout(), + self.stream.read_exact(&mut bytes[HEADER_LEN..HEADER_LEN + payload_len]), + 
) + .await + .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "transport read timed out"))??; + Ok(bytes) + } +} + +fn count_payloads(payloads: &[Payload]) -> (usize, usize, usize) { + let mut route_origins = 0; + let mut router_keys = 0; + let mut aspas = 0; + + for payload in payloads { + match payload { + Payload::RouteOrigin(_) => route_origins += 1, + Payload::RouterKey(_) => router_keys += 1, + Payload::Aspa(_) => aspas += 1, + } + } + + (route_origins, router_keys, aspas) +} + +fn count_payload_updates(updates: &[(bool, Payload)]) -> (usize, usize, usize, usize, usize) { + let mut announced = 0; + let mut withdrawn = 0; + let mut route_origins = 0; + let mut router_keys = 0; + let mut aspas = 0; + + for (announce, payload) in updates { + if *announce { + announced += 1; + } else { + withdrawn += 1; + } + + match payload { + Payload::RouteOrigin(_) => route_origins += 1, + Payload::RouterKey(_) => router_keys += 1, + Payload::Aspa(_) => aspas += 1, + } + } + + (announced, withdrawn, route_origins, router_keys, aspas) +} diff --git a/src/rtr/state.rs b/src/rtr/state.rs index b33c23a..e8f909f 100644 --- a/src/rtr/state.rs +++ b/src/rtr/state.rs @@ -1,14 +1,16 @@ +use crate::rtr::cache::SessionIds; + use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct State { - session_id: u16, + session_ids: SessionIds, serial: u32, } impl State { - pub fn session_id(self) -> u16 { - self.session_id + pub fn session_ids(self) -> SessionIds { + self.session_ids } pub fn serial(self) -> u32 { diff --git a/src/rtr/store.rs b/src/rtr/store.rs new file mode 100644 index 0000000..6434de8 --- /dev/null +++ b/src/rtr/store.rs @@ -0,0 +1,619 @@ +use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch}; +use anyhow::{anyhow, Result}; +use serde::{de::DeserializeOwned, Serialize}; +use std::path::Path; +use std::sync::Arc; +use tokio::task; +use tracing::{debug, info, warn}; + +use 
crate::rtr::cache::{CacheAvailability, Delta, SessionIds, Snapshot}; +use crate::rtr::state::State; + +const CF_META: &str = "meta"; +const CF_SNAPSHOT: &str = "snapshot"; +const CF_DELTA: &str = "delta"; + +const META_STATE: &[u8] = b"state"; +const META_SESSION_IDS: &[u8] = b"session_ids"; +const META_SERIAL: &[u8] = b"serial"; +const META_AVAILABILITY: &[u8] = b"availability"; +const META_DELTA_MIN: &[u8] = b"delta_min"; +const META_DELTA_MAX: &[u8] = b"delta_max"; + +const DELTA_KEY_PREFIX: u8 = b'd'; + +fn delta_key(serial: u32) -> [u8; 5] { + let mut key = [0u8; 5]; + key[0] = DELTA_KEY_PREFIX; + key[1..].copy_from_slice(&serial.to_be_bytes()); + key +} + +fn delta_key_serial(key: &[u8]) -> Option { + if key.len() != 5 || key[0] != DELTA_KEY_PREFIX { + return None; + } + let mut bytes = [0u8; 4]; + bytes.copy_from_slice(&key[1..]); + Some(u32::from_be_bytes(bytes)) +} + +#[derive(Clone)] +pub struct RtrStore { + db: Arc, +} + +impl RtrStore { + /// Open or create DB with required column families. + pub fn open>(path: P) -> Result { + let path_ref = path.as_ref(); + let mut opts = Options::default(); + opts.create_if_missing(true); + opts.create_missing_column_families(true); + + let cfs = vec![ + ColumnFamilyDescriptor::new(CF_META, Options::default()), + ColumnFamilyDescriptor::new(CF_SNAPSHOT, Options::default()), + ColumnFamilyDescriptor::new(CF_DELTA, Options::default()), + ]; + + info!("opening RTR RocksDB store at {}", path_ref.display()); + let db = Arc::new(DB::open_cf_descriptors(&opts, path_ref, cfs)?); + info!("opened RTR RocksDB store at {}", path_ref.display()); + + Ok(Self { db }) + } + + /// Common serialize/put. + fn put_cf(&self, cf: &str, key: &[u8], value: &T) -> Result<()> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + let data = serde_json::to_vec(value)?; + self.db.put_cf(cf_handle, key, data)?; + Ok(()) + } + + /// Common get/deserialize. 
+ fn get_cf(&self, cf: &str, key: &[u8]) -> Result> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + if let Some(value) = self.db.get_cf(cf_handle, key)? { + let obj = serde_json::from_slice(&value)?; + Ok(Some(obj)) + } else { + Ok(None) + } + } + + /// Common delete. + fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + self.db.delete_cf(cf_handle, key)?; + Ok(()) + } + + // =============================== + // Meta/state + // =============================== + + pub fn set_state(&self, state: &State) -> Result<()> { + self.put_cf(CF_META, META_STATE, &state) + } + + pub fn get_state(&self) -> Result> { + self.get_cf(CF_META, META_STATE) + } + + pub fn set_meta(&self, meta: &State) -> Result<()> { + self.set_state(meta) + } + + pub fn get_meta(&self) -> Result> { + self.get_state() + } + + pub fn set_session_ids(&self, session_ids: &SessionIds) -> Result<()> { + self.put_cf(CF_META, META_SESSION_IDS, session_ids) + } + + pub fn get_session_ids(&self) -> Result> { + self.get_cf(CF_META, META_SESSION_IDS) + } + + pub fn set_serial(&self, serial: u32) -> Result<()> { + self.put_cf(CF_META, META_SERIAL, &serial) + } + + pub fn get_serial(&self) -> Result> { + self.get_cf(CF_META, META_SERIAL) + } + + pub fn set_availability(&self, availability: CacheAvailability) -> Result<()> { + self.put_cf(CF_META, META_AVAILABILITY, &availability) + } + + pub fn get_availability(&self) -> Result> { + self.get_cf(CF_META, META_AVAILABILITY) + } + + pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> { + debug!( + "RTR store persisting delta window metadata: min_serial={}, max_serial={}", + min_serial, + max_serial + ); + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let mut batch = WriteBatch::default(); + batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); + 
batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub fn clear_delta_window(&self) -> Result<()> { + debug!("RTR store clearing delta window metadata"); + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let mut batch = WriteBatch::default(); + batch.delete_cf(meta_cf, META_DELTA_MIN); + batch.delete_cf(meta_cf, META_DELTA_MAX); + self.db.write(batch)?; + Ok(()) + } + + pub fn get_delta_window(&self) -> Result> { + let min: Option = self.get_cf(CF_META, META_DELTA_MIN)?; + let max: Option = self.get_cf(CF_META, META_DELTA_MAX)?; + + match (min, max) { + (Some(min), Some(max)) => { + debug!( + "RTR store loaded delta window metadata: min_serial={}, max_serial={}", + min, + max + ); + Ok(Some((min, max))) + } + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: delta window mismatch")), + } + } + + pub fn delete_state(&self) -> Result<()> { + self.delete_cf(CF_META, META_STATE) + } + + pub fn delete_serial(&self) -> Result<()> { + self.delete_cf(CF_META, META_SERIAL) + } + + // =============================== + // Snapshot + // =============================== + + pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> { + let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let mut batch = WriteBatch::default(); + let data = serde_json::to_vec(snapshot)?; + batch.put_cf(cf_handle, b"current", data); + self.db.write(batch)?; + Ok(()) + } + + pub fn get_snapshot(&self) -> Result> { + self.get_cf(CF_SNAPSHOT, b"current") + } + + pub fn delete_snapshot(&self) -> Result<()> { + self.delete_cf(CF_SNAPSHOT, b"current") + } + + pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> { + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; 
+ let mut batch = WriteBatch::default(); + + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_STATE, serde_json::to_vec(state)?); + batch.put_cf( + meta_cf, + META_SESSION_IDS, + serde_json::to_vec(&state.clone().session_ids())?, + ); + batch.put_cf( + meta_cf, + META_SERIAL, + serde_json::to_vec(&state.clone().serial())?, + ); + + self.db.write(batch)?; + Ok(()) + } + + pub fn save_snapshot_and_meta( + &self, + snapshot: &Snapshot, + session_ids: &SessionIds, + serial: u32, + ) -> Result<()> { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); + batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub fn save_cache_state( + &self, + availability: CacheAvailability, + snapshot: &Snapshot, + session_ids: &SessionIds, + serial: u32, + delta: Option<&Delta>, + delta_window: Option<(u32, u32)>, + clear_delta_window: bool, + ) -> Result<()> { + debug!( + "RTR store save_cache_state start: availability={:?}, serial={}, session_ids={:?}, delta_present={}, delta_window={:?}, clear_delta_window={}, snapshot(route_origins={}, router_keys={}, aspas={})", + availability, + serial, + session_ids, + delta.is_some(), + delta_window, + clear_delta_window, + snapshot.origins().len(), + snapshot.router_keys().len(), + snapshot.aspas().len() + ); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let delta_cf = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let mut batch = 
WriteBatch::default(); + + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); + batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); + batch.put_cf(meta_cf, META_AVAILABILITY, serde_json::to_vec(&availability)?); + + if let Some(delta) = delta { + debug!( + "RTR store persisting delta: serial={}, announced={}, withdrawn={}", + delta.serial(), + delta.announced().len(), + delta.withdrawn().len() + ); + batch.put_cf(delta_cf, delta_key(delta.serial()), serde_json::to_vec(delta)?); + } + + if clear_delta_window { + let existing_keys = self.list_delta_keys()?; + let existing_serials = summarize_delta_keys(&existing_keys); + info!( + "RTR store clearing persisted delta window: deleting {} delta records, serials={}", + existing_keys.len(), + existing_serials + ); + batch.delete_cf(meta_cf, META_DELTA_MIN); + batch.delete_cf(meta_cf, META_DELTA_MAX); + for key in existing_keys { + batch.delete_cf(delta_cf, key); + } + } else if let Some((min_serial, max_serial)) = delta_window { + batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); + batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); + // Serial numbers are compared in RFC 1982 ring order, while RocksDB stores + // keys in plain lexicographic order. After wraparound, a window such as + // [u32::MAX, 1] is contiguous in serial space but split in key space, so a + // simple delete_range() would leave stale high-serial keys behind. 
+ let stale_keys = self.list_delta_keys_outside_window(min_serial, max_serial)?; + if !stale_keys.is_empty() { + info!( + "RTR store pruning stale delta records outside window [{}, {}]: count={}, serials={}", + min_serial, + max_serial, + stale_keys.len(), + summarize_delta_keys(&stale_keys) + ); + } else { + debug!( + "RTR store found no stale delta records outside window [{}, {}]", + min_serial, + max_serial + ); + } + for key in stale_keys { + batch.delete_cf(delta_cf, key); + } + } + + self.db.write(batch)?; + debug!("RTR store save_cache_state completed: serial={}", serial); + Ok(()) + } + + pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub async fn save_snapshot_and_serial_async( + self: Arc, + snapshot: Snapshot, + serial: u32, + ) -> Result<()> { + let snapshot_bytes = serde_json::to_vec(&snapshot)?; + let serial_bytes = serde_json::to_vec(&serial)?; + + task::spawn_blocking(move || { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + batch.put_cf(snapshot_cf, b"current", snapshot_bytes); + batch.put_cf(meta_cf, META_SERIAL, serial_bytes); + self.db.write(batch)?; + Ok::<_, anyhow::Error>(()) + }) + .await??; + + Ok(()) + } + + pub fn load_snapshot_and_state(&self) -> Result> { + let snapshot: Option = self.get_snapshot()?; + let state: Option = self.get_state()?; + match (snapshot, state) { + (Some(snap), Some(state)) => 
Ok(Some((snap, state))), + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")), + } + } + + pub fn load_snapshot_and_serial(&self) -> Result> { + let snapshot: Option = self.get_snapshot()?; + let serial: Option = self.get_serial()?; + match (snapshot, serial) { + (Some(snap), Some(serial)) => Ok(Some((snap, serial))), + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")), + } + } + + // =============================== + // Delta + // =============================== + + pub fn save_delta(&self, delta: &Delta) -> Result<()> { + self.put_cf(CF_DELTA, &delta_key(delta.serial()), delta) + } + + pub fn get_delta(&self, serial: u32) -> Result> { + self.get_cf(CF_DELTA, &delta_key(serial)) + } + + pub fn load_deltas_since(&self, serial: u32) -> Result> { + let cf_handle = self + .db + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; + + let start_key = delta_key(serial.wrapping_add(1)); + let iter = self.db.iterator_cf( + cf_handle, + IteratorMode::From(&start_key, Direction::Forward), + ); + + let mut out = Vec::new(); + + for item in iter { + let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; + + let parsed = delta_key_serial(key.as_ref()) + .ok_or_else(|| anyhow!("Invalid delta key"))?; + + if parsed <= serial { + continue; + } + + let delta: Delta = serde_json::from_slice(value.as_ref())?; + out.push(delta); + } + + Ok(out) + } + + pub fn load_delta_window(&self, min_serial: u32, max_serial: u32) -> Result> { + info!( + "RTR store loading persisted delta window: min_serial={}, max_serial={}", + min_serial, + max_serial + ); + let cf_handle = self + .db + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start); + let mut out = Vec::new(); + + for item in iter { + let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; + 
let parsed = delta_key_serial(key.as_ref()) + .ok_or_else(|| anyhow!("Invalid delta key"))?; + + // Restore by the persisted window bounds instead of load_deltas_since(). + // The latter follows lexicographic key order and is not safe across serial + // wraparound, where older high-valued keys may otherwise be pulled back in. + if serial_in_window(parsed, min_serial, max_serial) { + let delta: Delta = serde_json::from_slice(value.as_ref())?; + out.push(delta); + } + } + + out.sort_by_key(|delta| delta.serial().wrapping_sub(min_serial)); + debug!( + "RTR store loaded delta candidates for window [{}, {}]: count={}, serials={}", + min_serial, + max_serial, + out.len(), + summarize_delta_serials(&out) + ); + validate_delta_window(&out, min_serial, max_serial)?; + info!( + "RTR store restored valid delta window: min_serial={}, max_serial={}, count={}, serials={}", + min_serial, + max_serial, + out.len(), + summarize_delta_serials(&out) + ); + Ok(out) + } + + pub fn delete_delta(&self, serial: u32) -> Result<()> { + self.delete_cf(CF_DELTA, &delta_key(serial)) + } + + fn list_delta_keys(&self) -> Result>> { + let cf_handle = self + .db + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start); + let mut keys = Vec::new(); + + for item in iter { + let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; + keys.push(key.to_vec()); + } + + Ok(keys) + } + + fn list_delta_keys_outside_window(&self, min_serial: u32, max_serial: u32) -> Result>> { + let cf_handle = self + .db + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start); + let mut keys = Vec::new(); + + for item in iter { + let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; + let serial = delta_key_serial(key.as_ref()) + .ok_or_else(|| anyhow!("Invalid delta key"))?; + if !serial_in_window(serial, 
min_serial, max_serial) { + keys.push(key.to_vec()); + } + } + + Ok(keys) + } +} + +fn serial_in_window(serial: u32, min_serial: u32, max_serial: u32) -> bool { + serial.wrapping_sub(min_serial) <= max_serial.wrapping_sub(min_serial) +} + +fn validate_delta_window(deltas: &[Delta], min_serial: u32, max_serial: u32) -> Result<()> { + if deltas.is_empty() { + warn!( + "RTR store delta window validation failed: no persisted deltas for window [{}, {}]", + min_serial, + max_serial + ); + return Err(anyhow!( + "delta window [{}, {}] has no persisted deltas", + min_serial, + max_serial + )); + } + + if deltas.first().map(Delta::serial) != Some(min_serial) { + warn!( + "RTR store delta window validation failed: first delta mismatch for window [{}, {}], got {:?}", + min_serial, + max_serial, + deltas.first().map(Delta::serial) + ); + return Err(anyhow!( + "delta window starts at {}, but first persisted delta is {:?}", + min_serial, + deltas.first().map(Delta::serial) + )); + } + + if deltas.last().map(Delta::serial) != Some(max_serial) { + warn!( + "RTR store delta window validation failed: last delta mismatch for window [{}, {}], got {:?}", + min_serial, + max_serial, + deltas.last().map(Delta::serial) + ); + return Err(anyhow!( + "delta window ends at {}, but last persisted delta is {:?}", + max_serial, + deltas.last().map(Delta::serial) + )); + } + + for pair in deltas.windows(2) { + if pair[1].serial() != pair[0].serial().wrapping_add(1) { + warn!( + "RTR store delta window validation failed: non-contiguous pair {} -> {} within window [{}, {}]", + pair[0].serial(), + pair[1].serial(), + min_serial, + max_serial + ); + return Err(anyhow!( + "persisted deltas are not contiguous: {} -> {}", + pair[0].serial(), + pair[1].serial() + )); + } + } + + Ok(()) +} + +fn summarize_delta_keys(keys: &[Vec]) -> String { + let serials: Vec = keys + .iter() + .filter_map(|key| delta_key_serial(key)) + .collect(); + summarize_serials(&serials) +} + +fn summarize_delta_serials(deltas: 
&[Delta]) -> String { + let serials: Vec = deltas.iter().map(Delta::serial).collect(); + summarize_serials(&serials) +} + +fn summarize_serials(serials: &[u32]) -> String { + const MAX_INLINE: usize = 12; + + if serials.is_empty() { + return "[]".to_string(); + } + + if serials.len() <= MAX_INLINE { + return format!("{:?}", serials); + } + + let head: Vec = serials.iter().take(6).copied().collect(); + let tail: Vec = serials + .iter() + .rev() + .take(3) + .copied() + .collect::>() + .into_iter() + .rev() + .collect(); + + format!("{:?} ... {:?} (total={})", head, tail, serials.len()) +} diff --git a/src/rtr/store_db.rs b/src/rtr/store_db.rs deleted file mode 100644 index e824032..0000000 --- a/src/rtr/store_db.rs +++ /dev/null @@ -1,310 +0,0 @@ -use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch}; -use anyhow::{anyhow, Result}; -use serde::{de::DeserializeOwned, Serialize}; -use std::path::Path; -use std::sync::Arc; -use tokio::task; - -use crate::rtr::cache::{Delta, Snapshot}; -use crate::rtr::state::State; - -const CF_META: &str = "meta"; -const CF_SNAPSHOT: &str = "snapshot"; -const CF_DELTA: &str = "delta"; - -const META_STATE: &[u8] = b"state"; -const META_SESSION_ID: &[u8] = b"session_id"; -const META_SERIAL: &[u8] = b"serial"; -const META_DELTA_MIN: &[u8] = b"delta_min"; -const META_DELTA_MAX: &[u8] = b"delta_max"; - -const DELTA_KEY_PREFIX: u8 = b'd'; - -fn delta_key(serial: u32) -> [u8; 5] { - let mut key = [0u8; 5]; - key[0] = DELTA_KEY_PREFIX; - key[1..].copy_from_slice(&serial.to_be_bytes()); - key -} - -fn delta_key_serial(key: &[u8]) -> Option { - if key.len() != 5 || key[0] != DELTA_KEY_PREFIX { - return None; - } - let mut bytes = [0u8; 4]; - bytes.copy_from_slice(&key[1..]); - Some(u32::from_be_bytes(bytes)) -} - -#[derive(Clone)] -pub struct RtrStore { - db: Arc, -} - -impl RtrStore { - /// Open or create DB with required column families. 
- pub fn open>(path: P) -> Result { - let mut opts = Options::default(); - opts.create_if_missing(true); - opts.create_missing_column_families(true); - - let cfs = vec![ - ColumnFamilyDescriptor::new(CF_META, Options::default()), - ColumnFamilyDescriptor::new(CF_SNAPSHOT, Options::default()), - ColumnFamilyDescriptor::new(CF_DELTA, Options::default()), - ]; - - let db = Arc::new(DB::open_cf_descriptors(&opts, path, cfs)?); - - Ok(Self { db }) - } - - /// Common serialize/put. - fn put_cf(&self, cf: &str, key: &[u8], value: &T) -> Result<()> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; - let data = serde_json::to_vec(value)?; - self.db.put_cf(cf_handle, key, data)?; - Ok(()) - } - - /// Common get/deserialize. - fn get_cf(&self, cf: &str, key: &[u8]) -> Result> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; - if let Some(value) = self.db.get_cf(cf_handle, key)? { - let obj = serde_json::from_slice(&value)?; - Ok(Some(obj)) - } else { - Ok(None) - } - } - - /// Common delete. 
- fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; - self.db.delete_cf(cf_handle, key)?; - Ok(()) - } - - // =============================== - // Meta/state - // =============================== - - pub fn set_state(&self, state: &State) -> Result<()> { - self.put_cf(CF_META, META_STATE, &state) - } - - pub fn get_state(&self) -> Result> { - self.get_cf(CF_META, META_STATE) - } - - pub fn set_meta(&self, meta: &State) -> Result<()> { - self.set_state(meta) - } - - pub fn get_meta(&self) -> Result> { - self.get_state() - } - - pub fn set_session_id(&self, session_id: u16) -> Result<()> { - self.put_cf(CF_META, META_SESSION_ID, &session_id) - } - - pub fn get_session_id(&self) -> Result> { - self.get_cf(CF_META, META_SESSION_ID) - } - - pub fn set_serial(&self, serial: u32) -> Result<()> { - self.put_cf(CF_META, META_SERIAL, &serial) - } - - pub fn get_serial(&self) -> Result> { - self.get_cf(CF_META, META_SERIAL) - } - - pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> { - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - let mut batch = WriteBatch::default(); - batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); - batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); - self.db.write(batch)?; - Ok(()) - } - - pub fn get_delta_window(&self) -> Result> { - let min: Option = self.get_cf(CF_META, META_DELTA_MIN)?; - let max: Option = self.get_cf(CF_META, META_DELTA_MAX)?; - - match (min, max) { - (Some(min), Some(max)) => Ok(Some((min, max))), - (None, None) => Ok(None), - _ => Err(anyhow!("Inconsistent DB state: delta window mismatch")), - } - } - - pub fn delete_state(&self) -> Result<()> { - self.delete_cf(CF_META, META_STATE) - } - - pub fn delete_serial(&self) -> Result<()> { - self.delete_cf(CF_META, META_SERIAL) - } - - // =============================== - // Snapshot 
- // =============================== - - pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> { - let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let mut batch = WriteBatch::default(); - let data = serde_json::to_vec(snapshot)?; - batch.put_cf(cf_handle, b"current", data); - self.db.write(batch)?; - Ok(()) - } - - pub fn get_snapshot(&self) -> Result> { - self.get_cf(CF_SNAPSHOT, b"current") - } - - pub fn delete_snapshot(&self) -> Result<()> { - self.delete_cf(CF_SNAPSHOT, b"current") - } - - pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> { - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - let mut batch = WriteBatch::default(); - - batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); - batch.put_cf(meta_cf, META_STATE, serde_json::to_vec(state)?); - batch.put_cf( - meta_cf, - META_SESSION_ID, - serde_json::to_vec(&state.clone().session_id())?, - ); - batch.put_cf( - meta_cf, - META_SERIAL, - serde_json::to_vec(&state.clone().serial())?, - ); - - self.db.write(batch)?; - Ok(()) - } - - pub fn save_snapshot_and_meta( - &self, - snapshot: &Snapshot, - session_id: u16, - serial: u32, - ) -> Result<()> { - let mut batch = WriteBatch::default(); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - - batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); - batch.put_cf(meta_cf, META_SESSION_ID, serde_json::to_vec(&session_id)?); - batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); - self.db.write(batch)?; - Ok(()) - } - - pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> { - let mut batch = WriteBatch::default(); 
- let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); - batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); - self.db.write(batch)?; - Ok(()) - } - - pub async fn save_snapshot_and_serial_async( - self: Arc, - snapshot: Snapshot, - serial: u32, - ) -> Result<()> { - let snapshot_bytes = serde_json::to_vec(&snapshot)?; - let serial_bytes = serde_json::to_vec(&serial)?; - - task::spawn_blocking(move || { - let mut batch = WriteBatch::default(); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - batch.put_cf(snapshot_cf, b"current", snapshot_bytes); - batch.put_cf(meta_cf, META_SERIAL, serial_bytes); - self.db.write(batch)?; - Ok::<_, anyhow::Error>(()) - }) - .await??; - - Ok(()) - } - - pub fn load_snapshot_and_state(&self) -> Result> { - let snapshot: Option = self.get_snapshot()?; - let state: Option = self.get_state()?; - match (snapshot, state) { - (Some(snap), Some(state)) => Ok(Some((snap, state))), - (None, None) => Ok(None), - _ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")), - } - } - - pub fn load_snapshot_and_serial(&self) -> Result> { - let snapshot: Option = self.get_snapshot()?; - let serial: Option = self.get_serial()?; - match (snapshot, serial) { - (Some(snap), Some(serial)) => Ok(Some((snap, serial))), - (None, None) => Ok(None), - _ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")), - } - } - - // =============================== - // Delta - // =============================== - - pub fn save_delta(&self, delta: &Delta) -> Result<()> { - self.put_cf(CF_DELTA, &delta_key(delta.serial()), delta) - } - - pub fn get_delta(&self, serial: u32) -> Result> 
{ - self.get_cf(CF_DELTA, &delta_key(serial)) - } - - pub fn load_deltas_since(&self, serial: u32) -> Result> { - let cf_handle = self - .db - .cf_handle(CF_DELTA) - .ok_or_else(|| anyhow!("CF_DELTA not found"))?; - - let start_key = delta_key(serial.wrapping_add(1)); - let iter = self.db.iterator_cf( - cf_handle, - IteratorMode::From(&start_key, Direction::Forward), - ); - - let mut out = Vec::new(); - - for item in iter { - let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; - - let parsed = delta_key_serial(key.as_ref()) - .ok_or_else(|| anyhow!("Invalid delta key"))?; - - if parsed <= serial { - continue; - } - - let delta: Delta = serde_json::from_slice(value.as_ref())?; - out.push(delta); - } - - Ok(out) - } - - pub fn delete_delta(&self, serial: u32) -> Result<()> { - self.delete_cf(CF_DELTA, &delta_key(serial)) - } -} diff --git a/tests/common/test_helper.rs b/tests/common/test_helper.rs index 9ac9cc3..bfb1676 100644 --- a/tests/common/test_helper.rs +++ b/tests/common/test_helper.rs @@ -1,12 +1,12 @@ -use std::net::{Ipv4Addr, Ipv6Addr}; use std::fmt::Write; +use std::net::{Ipv4Addr, Ipv6Addr}; use serde_json::{json, Value}; use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use rpki::rtr::cache::SerialResult; use rpki::rtr::payload::{Payload, RouteOrigin}; use rpki::rtr::pdu::{CacheResponse, EndOfDataV1, IPv4Prefix, IPv6Prefix}; -use rpki::rtr::cache::SerialResult; pub struct RtrDebugDumper { entries: Vec, @@ -163,12 +163,7 @@ pub fn v4_origin( RouteOrigin::new(prefix, max_len, asn.into()) } -pub fn v6_origin( - addr: Ipv6Addr, - prefix_len: u8, - max_len: u8, - asn: u32, -) -> RouteOrigin { +pub fn v6_origin(addr: Ipv6Addr, prefix_len: u8, max_len: u8, asn: u32) -> RouteOrigin { let prefix = v6_prefix(addr, prefix_len); RouteOrigin::new(prefix, max_len, asn.into()) } @@ -236,9 +231,8 @@ pub fn serial_result_to_string(result: &SerialResult) -> String { match result { SerialResult::UpToDate => 
"UpToDate".to_string(), SerialResult::ResetRequired => "ResetRequired".to_string(), - SerialResult::Deltas(deltas) => { - let serials: Vec = deltas.iter().map(|d| d.serial()).collect(); - format!("Deltas {:?}", serials) + SerialResult::Delta(delta) => { + format!("Delta serial={}", delta.serial()) } } } @@ -266,12 +260,7 @@ pub fn print_snapshot_hashes(label: &str, snapshot: &rpki::rtr::cache::Snapshot) ); } -pub fn test_report( - name: &str, - purpose: &str, - input: &str, - output: &str, -) { +pub fn test_report(name: &str, purpose: &str, input: &str, output: &str) { println!( "\n==================== TEST REPORT ====================\n测试名称: {}\n测试目的: {}\n\n【输入】\n{}\n【输出】\n{}\n====================================================\n", name, purpose, input, output @@ -307,16 +296,14 @@ pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) - rpki::rtr::cache::SerialResult::ResetRequired => { " result: ResetRequired\n".to_string() } - rpki::rtr::cache::SerialResult::Deltas(deltas) => { + rpki::rtr::cache::SerialResult::Delta(delta) => { let mut out = String::new(); - let _ = writeln!(&mut out, " result: Deltas"); - for (idx, delta) in deltas.iter().enumerate() { - let _ = writeln!(&mut out, " delta[{}].serial: {}", idx, delta.serial()); - let _ = writeln!(&mut out, " delta[{}].announced:", idx); - out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4)); - let _ = writeln!(&mut out, " delta[{}].withdrawn:", idx); - out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4)); - } + let _ = writeln!(&mut out, " result: Delta"); + let _ = writeln!(&mut out, " delta.serial: {}", delta.serial()); + let _ = writeln!(&mut out, " delta.announced:"); + out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4)); + let _ = writeln!(&mut out, " delta.withdrawn:"); + out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4)); out } } diff --git a/tests/fixtures/tls/client-bad.cnf 
b/tests/fixtures/tls/client-bad.cnf new file mode 100644 index 0000000..4d5b987 --- /dev/null +++ b/tests/fixtures/tls/client-bad.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = dn +prompt = no +req_extensions = req_ext + +[dn] +CN = RTR Test Client Bad + +[req_ext] +subjectAltName = @alt_names +extendedKeyUsage = clientAuth + +[alt_names] +IP.1 = 127.0.0.2 diff --git a/tests/fixtures/tls/client-bad.crt b/tests/fixtures/tls/client-bad.crt new file mode 100644 index 0000000..52b7d5a --- /dev/null +++ b/tests/fixtures/tls/client-bad.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDMTCCAhmgAwIBAgIUOiwVulOYZB53/kZ6zcqPl0p0g9UwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw +MFoXDTM2MDMyMDA4MDUwMFowHjEcMBoGA1UEAwwTUlRSIFRlc3QgQ2xpZW50IEJh +ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKatGSgyzw0FKy4n5pf6 +W7UUJCm26WI1+W1GsAkh2EfGofD3ynE6Ne15jRjLc/PeDe0xlEl4wTQm4Gs8kkil +JhHzG4fFxUyi4MeRuQPha419hZ+pK1BHUIlURbg1oeXO5g3kwfhLPbVAQcrJdvwz +uSEGj22OmqgQFmn0Uwl6xvl/unIIzL2FXGedQA4CitaXk9dqbQWmIFDKuh0q8Kim +ePmy+8aHGuD/SeeP31ZvGFv3WcvzsX5aysRdj/WWiJUWbv1vVJ8rApswgH1qC2gY +F9lQuEMtfgo+39cUioRzD/lkzl1LnNHr/a21WGOS8ojvAiK3RC49VABVs3fRYGhc +cTECAwEAAaNoMGYwDwYDVR0RBAgwBocEfwAAAjATBgNVHSUEDDAKBggrBgEFBQcD +AjAdBgNVHQ4EFgQUWk94GyULLgNLb+38lHzbrVEVrIMwHwYDVR0jBBgwFoAUtiTC +qdhd2TwFc7JZGWkF8BEcMJwwDQYJKoZIhvcNAQELBQADggEBACHR2a6IkoHt0iOQ +DjMi45fgybvddssC1B2aDABziQDabQkrei8leEFptgCCNtMSWkmrfgZJyI2tvuws +296kDhkhIEDiwsFh24ZblYaAuKmFBEmw/v6VotFeQ9Hrsf+KKOT3jGibIb2+3ho3 +q2X2e6ye7SOsfrs4hqeggcBQyzStSPoHE6KWNOHb6vBdJKavFsbff28mQIW3uEDK +j6xR7b6+xpuuuqwA41BulCTXIKGzDoIU1bet/7YmrlN69ElN97EfiT6u32dIqUN5 +dwQKRlTJdf/ZxXVvEemTUcUdRDdLLyG+RMhUfiAr56G+aBn8SzMVEjMGS7bB5dsa +FhEqWj4= +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/client-bad.key b/tests/fixtures/tls/client-bad.key new file mode 100644 index 0000000..849e702 --- /dev/null +++ b/tests/fixtures/tls/client-bad.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- 
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmrRkoMs8NBSsu +J+aX+lu1FCQptuliNfltRrAJIdhHxqHw98pxOjXteY0Yy3Pz3g3tMZRJeME0JuBr +PJJIpSYR8xuHxcVMouDHkbkD4WuNfYWfqStQR1CJVEW4NaHlzuYN5MH4Sz21QEHK +yXb8M7khBo9tjpqoEBZp9FMJesb5f7pyCMy9hVxnnUAOAorWl5PXam0FpiBQyrod +KvCopnj5svvGhxrg/0nnj99Wbxhb91nL87F+WsrEXY/1loiVFm79b1SfKwKbMIB9 +agtoGBfZULhDLX4KPt/XFIqEcw/5ZM5dS5zR6/2ttVhjkvKI7wIit0QuPVQAVbN3 +0WBoXHExAgMBAAECggEAEtQwKNngiQSB9cNdKeMG9CRT14CJyeX2CPG41jdEEwut +5KZhcLwWNn8KQPhO34hAw4Bb05b4IHeZ15NupRU/AT0Y6ZQeb0NhKDpej8Ep4MCC +1WALfBSqVPE3tREd+nOOiphCj1WUoYKiTBHJRsFjEweGMWawMvcqSQroTIRZsPqd +uphHGiyF5KGIPMZv6UrdC9pMllQrdYp6edu5Jk/EhX85dFv96eRwI8tutzvJ5gkt +itpX9l6reeO6s0FszcgbvzU5VxCOSV1Oo6izjdpocxnnGfO81Lc+DFy0XwJ3PCAz +ABE9SVO6TzRMcUAr9x44WffIsG6Ea9tDTHqmGdsm4QKBgQDTsOL2a+yT4bGHwkXl +GdlSNk1aLiZR/b6L1KEjsGQ77mGS+EZYouIo7BmAt3pQ/YQ/yqp/+G3RIPSqkg/w +n1iloaNAsCS7fIiOMrAKfhX6Y93Yi3orUKx+f9Y8vgTqbVXuyo4n2IvmG0Dqm6uA +YuuTvG+J3lu1Tfy6FwcMAsAYOQKBgQDJkCrg+HMkfvSykAFvMw94DEtOiWU4/bF/ +r6k299i++ubx1hqk48QWuMaXPK9FnsMzCeHukoHxi4JTQvu5rxED8MLg3nKmsl78 +efLcMKuSkzDtGriI4EMnAPHTcnrc7fxneUg944EOxjb2AAm0mQx95sA+q4qFvEy6 +Yq06tIJwuQKBgB6o+/Zc40L36VKUXLM17zftDX8GOB6f9b0i6sPUhG/5ssAqnWWx +EbiDmZ3+9QRN852ZqOAoBx/G+ijKRuy+54P1yUNRP8C35L9TsBOU93HwjO0UJnmn +kZQwx8K8ctHRTCTtyXET+A8320sfiNNrgFtBa5Y4UmgMB5KcSzT+IPxhAoGAEGFk ++q92PAsNO820MCNIKItnO1SzIzSKzkOqTstJlAuz5QdvVuMjtm0Bxpyp6dCDMIyn +DcpeQREDYFzbNDXj/hv82mV5j86DJaWLdRWHe/v2R+6Z/JWtH2hWPsbY8Udt8cLL +eiwY+uhk4w0RvNmLSFgOW4l5UnEBE0ydo120FBECgYEAxTNChdhMIHnFQQHLYI+6 +SR6LqnjPwqdCQSzy4NrnGJ9O7c+Sed1L5RVg88OKG2dZo+GGjpsSAFRv3gB6rOX9 +bYxklbPWC65bbWcIMNK2LCJ6+LahZzBwz2tBZOA5GBs0DunQ0PdqLnxdi7QZoHfb +pAuCQadp+4BQHVsDwIwpUH4= +-----END PRIVATE KEY----- diff --git a/tests/fixtures/tls/client-ca.crt b/tests/fixtures/tls/client-ca.crt new file mode 100644 index 0000000..5094787 --- /dev/null +++ b/tests/fixtures/tls/client-ca.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDGzCCAgOgAwIBAgIUKCaM+eEU12YVn/tC6MEBYPj50oIwDQYJKoZIhvcNAQEL 
+BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDQy +OFoXDTM2MDMyMDA4MDQyOFowHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENB +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmmS+gsRZXPaG20Ra2te7 +WVeqyeFX2lV/RvNsT3WW+EtwuTN2aFJ/pBgSCqgCaFocU5IcgfScybQrjmTWeEot +D0F1OOW/TihlBsrU3BS8vOZ4KTuwB9eOgpt7FL8CxLetjNWxZKBsvJMauOUYFLCt +DN+78EDqrZ2VyJPfa+7dJKS/JoRgs/WFjvAeb5x1UpKPf/gp1Ji+Vj92Ux9znYF1 +2Zfs77NsLdwcqyMVj7cb1nucutNf/vMqNi+BoSwC/tQkflda9bLUwto/ZfFPsCFw +Ptjwiwm+IGNAKG8r1ujaUVDtxc+NliuAbHERBgZGtYpAE6u28+d8cUAyGGQFNGqb +uQIDAQABo1MwUTAdBgNVHQ4EFgQUtiTCqdhd2TwFc7JZGWkF8BEcMJwwHwYDVR0j +BBgwFoAUtiTCqdhd2TwFc7JZGWkF8BEcMJwwDwYDVR0TAQH/BAUwAwEB/zANBgkq +hkiG9w0BAQsFAAOCAQEATr0wChmDvmw51RQ0kOFP1l3n+O1n/U0xR39YApnmdhtk +MyO9gNyqxdUWj4QaRkswr5iMyjsLrT3lYWGfh7oHlhIgkXEy5OK548rhj/PA6d5U +M7DDGI06EFfYXcC57Cx1y8Egd9gfkfPk0ned/TG+/dYtyb7sBiIKrOiDExbfZ53U +mNtfsdptBoOjee4KggplfVhkUopyO0KKD7twq4dvWgzWZHv7m5wrQA8GP8W4JpPP +ZAs8cOLfq4cxAwYFXEaf0x7is2u9KdagPqXkiPMeC8fOHWyRCR7sKN4Dnh57ruqq +mQXil9GGFfkDA/bh6+Rs93Yll7yYyA3xif6fSBmgUg== +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/client-ca.key b/tests/fixtures/tls/client-ca.key new file mode 100644 index 0000000..a7e9c3c --- /dev/null +++ b/tests/fixtures/tls/client-ca.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCaZL6CxFlc9obb +RFra17tZV6rJ4VfaVX9G82xPdZb4S3C5M3ZoUn+kGBIKqAJoWhxTkhyB9JzJtCuO +ZNZ4Si0PQXU45b9OKGUGytTcFLy85ngpO7AH146Cm3sUvwLEt62M1bFkoGy8kxq4 +5RgUsK0M37vwQOqtnZXIk99r7t0kpL8mhGCz9YWO8B5vnHVSko9/+CnUmL5WP3ZT +H3OdgXXZl+zvs2wt3ByrIxWPtxvWe5y601/+8yo2L4GhLAL+1CR+V1r1stTC2j9l +8U+wIXA+2PCLCb4gY0AobyvW6NpRUO3Fz42WK4BscREGBka1ikATq7bz53xxQDIY +ZAU0apu5AgMBAAECggEABjSKW6McnFnkLaffpvAIvZyCZr7B0yqghO9/qOnm+W++ +xhLFbYfzTVsSTo9WGW+Vt94leyujqY+uOHjhDdCdYwGUfobtW2zQMqewSnAi7cyt +g6q8dnQ5bBJnrfvHVrSzKvfju1GfTSz0Y/4BK8O2ENBlM1DIndW5kWgwEJx3EuPk +KzsjeZnWEUWKKcHL6/+jviTgfM7tnMjgtHMeHU/v8hVKBtebWAxvPjTE+HRoWePE 
+uEgvGC9eIHb7fRh0sEt4+vrx5GkKikwj4278yhTJTF4X4kTE8T1kBYaihAiSTyoy +jdf7DkGtfC//3k6yc3mSabq7UjMLu1m4HGCSXjs2iQKBgQDW/e+Qx7P/hyeya/gQ +JDAsZciBKxBn8gGywRbtEzFz08Wb1WMX0LaioUIVMXUmVzXWgQkaiztzS+AQTr+D +sRpbOqkxalUdzyaHzgUTKhyvU2dzPOWWFUwJvBPKShhUibukgl+BIIdXr701qKf5 +W/sL3e7CMxNfirKfMoXe4q4m1wKBgQC318k5+63TexbP/peY6A+bGudJo/PKE4aD +GexJEvUcEV1J5d3SKXZoe2V0fTmFsvdhJIPFg2rnVRy2tV6F1vFzHAiou3tc0skU +vlM8nTDwPI80kk44LRMzzqZoe5pziI3tUS/Qwu8FX3MfXkkKd0wnQRhoUc0fMdT2 +zOuvFehP7wKBgBQNsafeiNKf57sDySqwRXIOuGob+zbG4xOqYRoR/T3hlgAYIlsZ +U7/NrN1PNK8z2Ui91nyMWipB/I9o2QJOpbe2vAto8LGMHfry45RLDEvqSq78Eioy +qFoMGgh3ateP1Vnd80yXHSi3sr1rkud2he8wb1Hb88WoqUqiKsyEdlwXAoGBAIdN +V/rFoRv5BkQUAqx1di7YMQrAkIbTsfbA2Ga7fgunN/pQI94tx8iDsJp4IyKkIW6s +OhLecopI2LYba7KjC9aE9laAjP024OjUXlxI8CCO4XJ2jvzHJ8/EMjLJbVXEVXgo +fUFuhg11PzwB303FmRV20ijMs2NXAH6XOIoGXJCfAoGALCJ11wdb1m2vzYRLlSqo +mcDQ1PltSbnFZ6KyuJM5MnJ/TIwnUQ7Rnsjya5m4HFJtJnPSgrygckjrZwPLZpPz +TrKr614Ln4E3YU2IoVTyjjgtkEHn0fRcxvsn9z5vXCrzM4JJF5Ac+nbKEwgugmpu +JjMMsfGPFC/cr+SpRZr/Nf8= +-----END PRIVATE KEY----- diff --git a/tests/fixtures/tls/client-good.cnf b/tests/fixtures/tls/client-good.cnf new file mode 100644 index 0000000..46f8041 --- /dev/null +++ b/tests/fixtures/tls/client-good.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = dn +prompt = no +req_extensions = req_ext + +[dn] +CN = RTR Test Client Good + +[req_ext] +subjectAltName = @alt_names +extendedKeyUsage = clientAuth + +[alt_names] +IP.1 = 127.0.0.1 diff --git a/tests/fixtures/tls/client-good.crt b/tests/fixtures/tls/client-good.crt new file mode 100644 index 0000000..9f4e2ff --- /dev/null +++ b/tests/fixtures/tls/client-good.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDMjCCAhqgAwIBAgIUTLfIortKKjsYovPDuDVnUR8i2YMwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw +MFoXDTM2MDMyMDA4MDUwMFowHzEdMBsGA1UEAwwUUlRSIFRlc3QgQ2xpZW50IEdv +b2QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUjnhxQxKUNLjGN6mu 
+/VvoGXfsyvZq38QMlfIauGPQ5Jwpq2BU2gW1zFvjBJaS8ZKY2DOsBKvZMrEiSxVt +8WCCgvSucw5Elwniq/Yu8UmkvWEMnk9oXwQuHQeWYSe2eMvwBMTNaj3Vz+xo2Ga/ ++eGClxPDcKKvOMYz+BRpyMfHzX1LLrM26yDjyxQKxzsrFMQ5IKXkhv4COLAjpLZd +oEDSkG+W4fbGTC6bSkskFP62OmnR/Qq6dXBVAJkzHqoiWEmKWBQy9JhEo+N3Jz7D +QfOAI/20MD6J4brWS56ciHL67yQjHJW6P0HlcLzefFqspIwvTFyCsF137jJi9sqL +y7UlAgMBAAGjaDBmMA8GA1UdEQQIMAaHBH8AAAEwEwYDVR0lBAwwCgYIKwYBBQUH +AwIwHQYDVR0OBBYEFJy0O+AR5EdWrUr2UHtPCUVTrRiEMB8GA1UdIwQYMBaAFLYk +wqnYXdk8BXOyWRlpBfARHDCcMA0GCSqGSIb3DQEBCwUAA4IBAQB8eavz/Pk3oJBh +sgy4ve6nLAqKGcD+zSoBxw/ErZaDNdUZZaFKT9nypfi39jkfjU+CIfhczCs5Cknk +EDqYtp1BXkY+auKeYRkoaCKv+ucnIc6JZ70NOQGdDNzU4eb9tVA9Py0j5VzvxWR1 +yl6vHTc3tTA873RRezvUe+SQDwR+x2dxZ0O0MFZR4CTaTeR+fpD4bE9mBcYN1KVk +HdA7USCwOZ0qE+aHTegM4pgGrpHYS5yBpUFADqfJl1yBnnqc10zkphoy6dRsbVuN +NkDostGbbfGdGjEPs6i1TMsDTVyOo4nsHACk4q4m/IEVBkiMwxmwyhK1EEWQmQpy +h0CqAs3Z +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/client-good.key b/tests/fixtures/tls/client-good.key new file mode 100644 index 0000000..07bada9 --- /dev/null +++ b/tests/fixtures/tls/client-good.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDUjnhxQxKUNLjG +N6mu/VvoGXfsyvZq38QMlfIauGPQ5Jwpq2BU2gW1zFvjBJaS8ZKY2DOsBKvZMrEi +SxVt8WCCgvSucw5Elwniq/Yu8UmkvWEMnk9oXwQuHQeWYSe2eMvwBMTNaj3Vz+xo +2Ga/+eGClxPDcKKvOMYz+BRpyMfHzX1LLrM26yDjyxQKxzsrFMQ5IKXkhv4COLAj +pLZdoEDSkG+W4fbGTC6bSkskFP62OmnR/Qq6dXBVAJkzHqoiWEmKWBQy9JhEo+N3 +Jz7DQfOAI/20MD6J4brWS56ciHL67yQjHJW6P0HlcLzefFqspIwvTFyCsF137jJi +9sqLy7UlAgMBAAECggEALmsOkm17WTJKR79QJw7dS0qEjgmk1qIXRkhYns01vyCt +mcv7NYyHQrRmPKV73Is04HwWjLJYdQ5E8KBFBcV4tgezN4WY0BHL7txu3sGCu58/ +2mmYHcriNs/QIF8HNSocH0ZrVBCngFHv5tWbWsFPJh2oCz5FyM41OpQqoQ9f0ZoA +PKq3m7dNidFmlQPhhJ9KeLYwDfzTRy4sWBDrgZFDQ6ut9nO8/JpJvxZCSnOdbW1T +6pYk3uAKKcmrcdI2fsg7u0J1zkrkDnRdrDv1QCY/h1UXvR4KuMeYRylHWJpOKUiR +asxpJO+lh7TmDaEIqO+dSmTHawcKDYuIp0EV4/x7AQKBgQD/cgMhx97bJ3ym0ucc +xtt10jo0NTW7SJIXx/eUdBQNA3aGWfsqLrBI5qKrXLxTXF/vAs3LWHiHF6EgRxW/ 
+7RbO+NZzAGxg53Y3PMqZ4AQbOm2JCIgT9ZhQc55+qowdtrts3njL+ayy3vWh3nPt +G78+rs4ZLCHhEutbkKqU7S3XSQKBgQDVBJ5l9d1QcqFqYnlmyhgfwJwpxo4N2dJR +18pCow8Mp8S9ROJx3OF4yNTRa6ipnX0oGdgchT63jXBDoKot3/4tqbjYj/Y3MmKL +crZ7dZEiAqkqE975qUfSlQS6BtknRKduzQzAJAKz0A+JbPIisPHIrBnG2PH9h2XA +TjA2rHRi/QKBgDIp65+IpqUW/g2swSIPky1yGWgDQwgCWl49MMuAeCeOFIqRxRcl +kAzg7fUFAx7Dtzsyq8NRHmo5I7U5AHZuUtpWV5bB8IafLcHvOEI7kdLfCH+uozp4 +Mm8qJWfuihGTvv7EOaik4VtHGamuC8n2dvoSTfr3hbezhXC32ifg4+2xAoGATbjp +snoKzheFHbPgZ8jFFJDKadOwcQ1Q19vMSJQGIa/08Ln5hWH6Qn/EZsTJPVnhGIiV +eZKEV6SbmZE9ho97xl1uvFWKmIkhu4+XVWSIF8iwwFGPwbgqJIOKvfVRtioujRbz +2AdLlSANCy9dCZtWHMnufccaRE7qqUfd/5TcwmECgYA5md1JZ+p+0yTkjNf4McCe +mpZiXXaM8uC3/XiWlfcw3ag88H0XCqt2UPcX03MC9ZzywfTdoEHhPRfr48uRiH7K +Pi1T8E2E0GkSfk6zh6wJm8Kmm9+VJd2rUFrb35xDMUfjpS228oH9UBQtwaAV+jmy +q/APoG+aWt5YIeYElx9mzA== +-----END PRIVATE KEY----- diff --git a/tests/fixtures/tls/server-dns.cnf b/tests/fixtures/tls/server-dns.cnf new file mode 100644 index 0000000..38300a5 --- /dev/null +++ b/tests/fixtures/tls/server-dns.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = dn +prompt = no +req_extensions = req_ext + +[dn] +CN = localhost + +[req_ext] +subjectAltName = @alt_names +extendedKeyUsage = serverAuth + +[alt_names] +DNS.1 = localhost diff --git a/tests/fixtures/tls/server-dns.crt b/tests/fixtures/tls/server-dns.crt new file mode 100644 index 0000000..4f8be31 --- /dev/null +++ b/tests/fixtures/tls/server-dns.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDLDCCAhSgAwIBAgIUU9TJAAXI1pavsowFKFye0t8gdz4wDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyNDAyMzA0 +N1oXDTM2MDMyMTAyMzA0N1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzXLZ2ZeySuuzCDEPXsHMym0py7Tuwizu +5a4AKQt+7d8b51x22W+I6ZQXYOCUml+C0PGs5TfQDXs0VbYj08flVxw9732FPqzp +PhkZJlEbnu7+nEW4mb+aM+pnPnW7iLeCvePGXicYO9YctQNA9Onxe+63QtEi2OtO +24wlqreToA2qQ7UOeNbpcY5p9psgXt9wW+8rq1gXmKsci8DV//1MessDTTX4tgIh +lOvKBIqx2E2kIIbZ6BwB7veKCkwqrFeVKzmwCXfkVQZkyfxKTBOvf3VFssnFJ7Ld 
+dYuAMTpl9jCHkNykGNlc0a9ohglEvAjqwgpUdMcuzMMR8Kf0DjqZqwIDAQABo20w +azAUBgNVHREEDTALgglsb2NhbGhvc3QwEwYDVR0lBAwwCgYIKwYBBQUHAwEwHQYD +VR0OBBYEFOYSmMCX8XUYA9jMGY2M1v5/xaimMB8GA1UdIwQYMBaAFLYkwqnYXdk8 +BXOyWRlpBfARHDCcMA0GCSqGSIb3DQEBCwUAA4IBAQArYXsdMl+vPyTyv/oRzCo0 +mWyVO3RZPPRJSrrMz4UjtVMIfq6pnV+AXPEcWL/zEfQNcRPjsJeezkG66AZbI+ug +fur90YVDhWfOgde4E3cVhZz90aM/jcRMVvwNj0XiaX4JpxsVhv5T4LC80aXm0r02 +YfJyqwtNaZlKDntaW56q7nD5eaoqmYa+ogpdqwCIfGvManfRH6v6xmzgQRKnD1lc +LZPDZ9dmkQg2N/vdZfVUpB2+EZYF/9BuIvJyGKNYBjJiQGv2kbaFUl13mw7D/yGP +Zytpu5AIPl6ScYog8x6dSJMiYM+hO1bD9qOU0Kq48PaLQoQ4poYo7zdbiB8TZaKC +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/server-dns.key b/tests/fixtures/tls/server-dns.key new file mode 100644 index 0000000..09247a4 --- /dev/null +++ b/tests/fixtures/tls/server-dns.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDNctnZl7JK67MI +MQ9ewczKbSnLtO7CLO7lrgApC37t3xvnXHbZb4jplBdg4JSaX4LQ8azlN9ANezRV +tiPTx+VXHD3vfYU+rOk+GRkmURue7v6cRbiZv5oz6mc+dbuIt4K948ZeJxg71hy1 +A0D06fF77rdC0SLY607bjCWqt5OgDapDtQ541ulxjmn2myBe33Bb7yurWBeYqxyL +wNX//Ux6ywNNNfi2AiGU68oEirHYTaQghtnoHAHu94oKTCqsV5UrObAJd+RVBmTJ +/EpME69/dUWyycUnst11i4AxOmX2MIeQ3KQY2VzRr2iGCUS8COrCClR0xy7MwxHw +p/QOOpmrAgMBAAECggEAKOWyBCS0c0GUNA9AFgbSM4Gjjk+IL9MuAz6t/X2yWLvP +HDFF32bahFTcioZqToiwy9MwLbj8i5J5Co7lduUV/E887Q25lo5pJ9lrLjt7GhpN +SOKAKur/UVJaFw7ss/yD5DURafCyXEb1E/t/1ME1NwyAIqbrHu9IlV/Cp8c/dd6n +BUZHKN38Fs8riUOTW4BwAA1dV0k+g8bKb0byR5WmXbXqWG/XjfUKcI4Dg9ILsTwG +QgI/nzRyAV2unVcdNhM4pMTJ4U4x1Dh/rsWyQ4b206namkpfaIH7j8ovRY2OHhzT +EuCbn1Cpfmxp68hRrvMIhLvqVkXOSWzKOkHZba/qYQKBgQD0m3WJu5UHXip14gY/ +SmDm3vEcFjVZa+BmQtojs32ZfT/b1H5U+mfhXlVi0yf7B7nitQ9zOA7A2C8cxdtr +X9zd81MHLc7yUEOAlHBqSfBX/DyF1iRsdM7GYYxrsaECM7NIDukDR1Q/DvMZodRx +iVpRkNP3s87a1/QwvDKrwOPcCQKBgQDXBH1IZWcl1mrGgfAI7r5wXIJo0IrTlUcT +4nFSNv5YnrhnRzmLDRqxwlUs4II87nzsqv7OYryAg4n8us+X6XQv0hvTKcHsM7tl +vAO2EeQIER71tEg45HhmAXo2X2bE5kNMTr4COlb8k6R6r5K2F4yVMZim4qPcMYAK 
+tNRXf9JdEwKBgQDEjpOxvPmxdPrbxWfNvgAGJYpMTpBKLgShSAEwhRBdoacKCEQI +FzwYfoxQoGtVLk0yHtqudJJuZondLiT2sI60D85dS3MrhlHn5eA7mPS4Tyl3Rq/4 +Mxjhkwuakp9WPKNJOSoHB29sSKASrdcf8QaR2rZqKqQDeVtxOhnhqFuxuQKBgHrk +8/53BteXj/vZtKpGWs658UebOl3oinGREZgeGo3oWhmdmgQh/0nueuRlhcrxvLFA +otavlHIXvLyYwaJgKqpSetjcmxw4DTn+llhwLVd3Aa0J1+W8oBwdaA6/xGtx+LEa +qHt5gNJoSLBevYoaN53mdQudqm5mVHrKFDvWsRPFAoGBALAkJiXM4y+aTcwREjnO +CgebDRru37mprDW2PDa4vpTydxIcXZukTEZqXbVnXNDY/Gkyv06jYoY2R7mzbgfl +7RJIFAQffkzoeFadQZqimIAN3sQ6RRzNqmAlS2D6zjylX7N2rDiPcS3LtrJZHEyk +gTFx7+6gjfPNqPiAB5su0IJm +-----END PRIVATE KEY----- diff --git a/tests/fixtures/tls/server.cnf b/tests/fixtures/tls/server.cnf new file mode 100644 index 0000000..aaf9df5 --- /dev/null +++ b/tests/fixtures/tls/server.cnf @@ -0,0 +1,14 @@ +[req] +distinguished_name = dn +prompt = no +req_extensions = req_ext + +[dn] +CN = RTR Test Server + +[req_ext] +subjectAltName = @alt_names +extendedKeyUsage = serverAuth + +[alt_names] +IP.1 = 127.0.0.1 diff --git a/tests/fixtures/tls/server.crt b/tests/fixtures/tls/server.crt new file mode 100644 index 0000000..0ae330e --- /dev/null +++ b/tests/fixtures/tls/server.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDLTCCAhWgAwIBAgIUOiwVulOYZB53/kZ6zcqPl0p0g9QwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw +MFoXDTM2MDMyMDA4MDUwMFowGjEYMBYGA1UEAwwPUlRSIFRlc3QgU2VydmVyMIIB +IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAn1qUW0bUWsq9pMNAzukncqyb +3sTbVcO+KfZfzBRwFfbWP1HmUziIsBq01b1m6Py6CL/DWBjgw4lJ8ZgxlhbjB9Zu +DVCietURIZIGMjuBiKOGW8vkCd3hlZqNph5sp2rUT2RBKym+orIApkt75epZB1cW +0KGsgnAhRpY6QAp4xalggtVp3HvUqapnVkRgix9AF3EAJJml+8QhGjgAn5tzFlOU +s5PjHeCDawTgfsAcoyRLi2E/OWAqJoEjujIPrkzfTC8d21yavQvIVeP+OyBRBwrw +kNZTyYy5LjO4tqZr8yLKdcMtIdAK2+7lEwWpR/LjFTYXwGs5GEb1Obon1ZdxEQID +AQABo2gwZjAPBgNVHREECDAGhwR/AAABMBMGA1UdJQQMMAoGCCsGAQUFBwMBMB0G +A1UdDgQWBBQQ7EAHYbMAQW6UdIcRjUmpMOyboTAfBgNVHSMEGDAWgBS2JMKp2F3Z +PAVzslkZaQXwERwwnDANBgkqhkiG9w0BAQsFAAOCAQEAffuJO/uKj23/3uFCt06+ 
+JT/8lozBSCifvileFjMZmNHViWT3cZuD2NM1/0uFNf9k95yyFn8yVeS4RiobX75L +rcqILZUesF5bkD7lmPdazR3Cz0BcCGo9+DvgWaX35lBOsuIMXRG3aUs6/x939zyc +p/KtDjqAgd26nhefJlZeb+Z9UYdmMXzcD5n8doWmNgYeaVQraBvMazenSmcwLzn6 +qzHjW5osUILBSl36KxuChHjHSa9aFPIrhoiGMR0w9oOGm6cHE3R4seNJNAX+YoZq +9PXgvVAHWvoSUuATY8iySNRiWUOmpybQuw3iqfgGighekZfxukLMSx2jq4WIquDg +Aw== +-----END CERTIFICATE----- diff --git a/tests/fixtures/tls/server.key b/tests/fixtures/tls/server.key new file mode 100644 index 0000000..01e8ea5 --- /dev/null +++ b/tests/fixtures/tls/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCfWpRbRtRayr2k +w0DO6SdyrJvexNtVw74p9l/MFHAV9tY/UeZTOIiwGrTVvWbo/LoIv8NYGODDiUnx +mDGWFuMH1m4NUKJ61REhkgYyO4GIo4Zby+QJ3eGVmo2mHmynatRPZEErKb6isgCm +S3vl6lkHVxbQoayCcCFGljpACnjFqWCC1Wnce9SpqmdWRGCLH0AXcQAkmaX7xCEa +OACfm3MWU5Szk+Md4INrBOB+wByjJEuLYT85YComgSO6Mg+uTN9MLx3bXJq9C8hV +4/47IFEHCvCQ1lPJjLkuM7i2pmvzIsp1wy0h0Arb7uUTBalH8uMVNhfAazkYRvU5 +uifVl3ERAgMBAAECggEADNbvCjW4SYJ5akujLqjBmW9H9diVtaDacyYbTOW+rD5h +v+NY9A6jkNDuDiS/JHmsgaT11+TVQ1wN/a3eHPJGI60G3ALJvKzrPvG1lxmNU9Wd +L0tL2fGrSfMUg7SC27BzX9w7lf88kX5XKA7/8iQCPWGqgG/uZuoi/D2BfwR4+7AL +0jPnElGK2PAHBWcEWvAQVfBRxPB96trF58+FLE+blQff2LY3AOR5pVRVwdK6V148 +4bT6tlPFcUabyVI70mqpw1FjoocEwMb8Fd7uOco5/0FulG3qe9DpZAjM2MWc4LRH +K3ZSI/QUp7CGEPcPeah3I/XHdqNrzRwKYj1nM3csoQKBgQDf/1NarvGaJNexjYHn +TZOZ6AOZZsP4MqdcNqNzV5K/Gmy1kZ5r7e5eH+SuZ1qH9jNLgNlusGWjmmrivYJo +qoiVDzE1tO5m3r08w/5187+TWuhUSOgObKFjtCS+GSR+YbGp/sxoWc5uXm0UZune +mWjtPRVwEXBCyrmIAGCGJGYMmQKBgQC2HuzGeJT1e+JxpgvDZKYU9mg0MDYQhejV +ibOUfSedSdw2P1vxQWAlQSPHbVpftMqXeM48b/B6IkzvwwSVNxV2I1agxFj0ZLTx +MQhDhet0CCFC0UBT9/Ch6JWQ4FbrM5OaNhLVcfeksc8B7cTYZB/lMeSE2r1GH8t1 +HNqj8+6bOQKBgGwOYKiLYlOI2GB3siXh34VMTogu8fSGgwPR+9GFem4kEjMY10Kb +mfTgD9IuW5bhJueSddGW2MEumcddwk45jf/SP1v4N1V6t/FbXyKJfm5YWWFndkKX +FtfhLCRkPp2VBT7LgtIIGLRXaul/p+xRNzPS1sekMfKWlx/LhsTPREdpAoGBAKZY +UMRnVwdyBD7x/0SVJe13s24Xqwokhaqlf9VdC1XrJKyX6o7Nu9fLS7bX9vf71h/M 
+Q/OH+wpTUhqc8g6opX2mgXWOYgG4Cl1S/81NAOaWlmrFXhBUIwJ/wjz16+4gyezM +/x7eXeecUQvd9TIBIfDiRWvjr4XhfKCXnkyqfYJhAoGAGMUedwQLN11NLclbTiVo ++xgy1yMeZV8c49+mocwwLgoxINBwygSU3klZZ4Bg+whzaVVlHccrHtEhw+73CbiM +xmfIuD0C/U/0FIP2kPm/g8rtH1QVE6Rkshwzzf+txeaGSlXvVtq4zjlVbDi+UrYk +qhSRT6NNKMROWj3Mx3CIV/A= +-----END PRIVATE KEY----- diff --git a/tests/test_cache.rs b/tests/test_cache.rs index 013c9f5..46de18a 100644 --- a/tests/test_cache.rs +++ b/tests/test_cache.rs @@ -1,4 +1,4 @@ -mod common; +mod common; use std::collections::VecDeque; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -9,9 +9,14 @@ use common::test_helper::{ serial_result_detail_to_string, snapshot_hashes_to_string, test_report, v4_origin, v6_origin, }; -use rpki::rtr::cache::{Delta, RtrCacheBuilder, SerialResult, Snapshot}; -use rpki::rtr::payload::{Payload, Timing}; -use rpki::rtr::store_db::RtrStore; +use rpki::data_model::resources::as_resources::Asn; +use rpki::rtr::cache::{ + CacheAvailability, Delta, RtrCacheBuilder, SerialResult, SessionIds, Snapshot, + validate_payload_updates_for_rtr, + validate_payloads_for_rtr, +}; +use rpki::rtr::payload::{Aspa, Payload, Timing}; +use rpki::rtr::store::RtrStore; fn delta_to_string(delta: &Delta) -> String { format!( @@ -35,15 +40,10 @@ fn deltas_window_to_string(deltas: &VecDeque>) -> String { out } -fn get_deltas_since_input_to_string( - cache_session_id: u16, - cache_serial: u32, - client_session: u16, - client_serial: u32, -) -> String { +fn get_deltas_since_input_to_string(cache_session_id: u16, cache_serial: u32, client_serial: u32) -> String { format!( - "cache.session_id: {}\ncache.serial: {}\nclient_session: {}\nclient_serial: {}\n", - cache_session_id, cache_serial, client_session, client_serial + "cache.session_id: {}\ncache.serial: {}\nclient_serial: {}\n", + cache_session_id, cache_serial, client_serial ) } @@ -56,6 +56,8 @@ fn snapshot_hashes_and_sorted_view_to_string(snapshot: &Snapshot) -> String { ) } +/// Snapshot ?hash ? 
+/// payload snapshot_hash / origins_hash ? #[test] fn snapshot_hash_is_stable_for_same_content_with_different_input_order() { let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -82,13 +84,13 @@ fn snapshot_hash_is_stable_for_same_content_with_different_input_order() { let s2 = Snapshot::from_payloads(s2_input.clone()); let input = format!( - "s1 原始输入 payloads:\n{}\ns2 原始输入 payloads:\n{}", + "s1 payloads:\n{}\ns2 payloads:\n{}", indent_block(&payloads_to_string(&s1_input), 2), indent_block(&payloads_to_string(&s2_input), 2), ); let output = format!( - "s1:\n{}\ns2:\n{}\n结论:\n same_content: {}\n same_origins: {}\n snapshot_hash 相同: {}\n origins_hash 相同: {}\n", + "s1:\n{}\ns2:\n{}\n:\n same_content: {}\n same_origins: {}\n snapshot_hash : {}\n origins_hash : {}\n", indent_block(&snapshot_hashes_and_sorted_view_to_string(&s1), 2), indent_block(&snapshot_hashes_and_sorted_view_to_string(&s2), 2), s1.same_content(&s2), @@ -99,7 +101,7 @@ fn snapshot_hash_is_stable_for_same_content_with_different_input_order() { test_report( "snapshot_hash_is_stable_for_same_content_with_different_input_order", - "验证相同语义内容即使原始输入顺序不同,Snapshot 的 hash 仍然稳定一致。", + "test purpose", &input, &output, ); @@ -110,6 +112,96 @@ fn snapshot_hash_is_stable_for_same_content_with_different_input_order() { assert_eq!(s1.origins_hash(), s2.origins_hash()); } +#[tokio::test] +async fn init_keeps_cache_running_when_file_loader_returns_no_data() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let cache = rpki::rtr::cache::RtrCache::default() + .init(&store, 16, Timing::new(600, 600, 7200), || Ok(vec![])) + .unwrap(); + + assert!(!cache.is_data_available()); + assert_eq!(cache.serial(), 0); + assert!(cache.snapshot().payloads_for_rtr().is_empty()); +} + +#[tokio::test] +async fn init_restores_wraparound_delta_window_from_store() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = 
SessionIds::from_array([42, 43, 44]); + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), + Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498)), + ]); + + let d_max = Delta::new( + u32::MAX, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + ); + let d_zero = Delta::new( + 0, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + ); + let d_one = Delta::new( + 1, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + ); + + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + u32::MAX, + Some(&d_max), + Some((u32::MAX, u32::MAX)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 0, + Some(&d_zero), + Some((u32::MAX, 0)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 1, + Some(&d_one), + Some((u32::MAX, 1)), + false, + ) + .unwrap(); + + let cache = rpki::rtr::cache::RtrCache::default() + .init(&store, 16, Timing::new(600, 600, 7200), || Ok(Vec::new())) + .unwrap(); + + match cache.get_deltas_since(u32::MAX.wrapping_sub(1)) { + SerialResult::Delta(delta) => { + assert_eq!(delta.serial(), 1); + assert_eq!(delta.announced().len(), 3); + } + _ => panic!("expected wraparound delta to be restored from store"), + } +} + +/// Snapshot::diff() ? +/// old_snapshot ?new_snapshot announced?withdrawn? 
#[test] fn snapshot_diff_reports_announced_and_withdrawn_correctly() { let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -136,7 +228,7 @@ fn snapshot_diff_reports_announced_and_withdrawn_correctly() { let (announced, withdrawn) = old_snapshot.diff(&new_snapshot); let input = format!( - "old_snapshot 原始输入:\n{}\nnew_snapshot 原始输入:\n{}", + "old_snapshot :\n{}\nnew_snapshot :\n{}", indent_block(&payloads_to_string(&old_input), 2), indent_block(&payloads_to_string(&new_input), 2), ); @@ -149,7 +241,7 @@ fn snapshot_diff_reports_announced_and_withdrawn_correctly() { test_report( "snapshot_diff_reports_announced_and_withdrawn_correctly", - "验证 diff() 能正确找出 announced 和 withdrawn 的 payload。", + "test purpose", &input, &output, ); @@ -168,6 +260,8 @@ fn snapshot_diff_reports_announced_and_withdrawn_correctly() { } } +/// Snapshot::payloads_for_rtr() ? +/// IPv4 Prefix IPv6 Prefix ?IPv4 announcement ?RTR ? #[test] fn snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending() { let v4_low = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -189,18 +283,18 @@ fn snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_desce let output_payloads = snapshot.payloads_for_rtr(); let input = format!( - "原始输入 payloads(构造 Snapshot 前):\n{}", + " payloads?Snapshot :\n{}", indent_block(&payloads_to_string(&input_payloads), 2), ); let output = format!( - "排序后 payloads_for_rtr:\n{}", + "?payloads_for_rtr:\n{}", indent_block(&payloads_to_string(&output_payloads), 2), ); test_report( "snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending", - "验证 Snapshot::payloads_for_rtr() 会按 RTR 规则排序:IPv4 在 IPv6 前,且 IPv4 announcement 按地址降序。", + "test purpose", &input, &output, ); @@ -227,6 +321,8 @@ fn snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_desce ); } +/// Delta::new() ?announced ?withdrawn? +/// announced announcement ithdrawn withdrawal ? 
#[test] fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { let announced_low = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -246,7 +342,7 @@ fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { let delta = Delta::new(101, input_announced.clone(), input_withdrawn.clone()); let input = format!( - "announced(构造前):\n{}withdrawn(构造前):\n{}", + "announced?\n{}withdrawn?\n{}", indent_block(&payloads_to_string(&input_announced), 2), indent_block(&payloads_to_string(&input_withdrawn), 2), ); @@ -255,7 +351,7 @@ fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { test_report( "delta_new_sorts_announced_descending_and_withdrawn_ascending", - "验证 Delta::new() 会自动排序:announced 按 RTR announcement 规则,withdrawn 按 RTR withdrawal 规则。", + "test purpose", &input, &output, ); @@ -284,22 +380,24 @@ fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { ); } +/// serial ?serial +/// get_deltas_since() ?UpToDate?Delta ?ResetRequired? #[test] fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() { let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .build(); - let result = cache.get_deltas_since(42, 100); + let result = cache.get_deltas_since(100); - let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 100); + let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100); let output = serial_result_detail_to_string(&result); test_report( "get_deltas_since_returns_up_to_date_when_client_serial_matches_current", - "验证当客户端 serial 与缓存当前 serial 相同,返回 UpToDate。", + "test purpose", &input, &output, ); @@ -310,33 +408,7 @@ fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() { } } -#[test] -fn get_deltas_since_returns_reset_required_on_session_mismatch() { - let cache = RtrCacheBuilder::new() - .session_id(42) - .serial(100) - 
.timing(Timing::default()) - .build(); - - let result = cache.get_deltas_since(999, 100); - - let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 999, 100); - let output = serial_result_detail_to_string(&result); - - test_report( - "get_deltas_since_returns_reset_required_on_session_mismatch", - "验证当客户端 session_id 与缓存 session_id 不一致时,返回 ResetRequired。", - &input, - &output, - ); - - match result { - SerialResult::ResetRequired => {} - _ => panic!("expected ResetRequired"), - } -} - -#[test] +/// serial delta window ?/// get_deltas_since() ?ResetRequired?#[test] fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() { let d1 = Arc::new(Delta::new( 101, @@ -354,24 +426,24 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() { deltas.push_back(d2); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(102) .timing(Timing::default()) .deltas(deltas.clone()) .build(); - let result = cache.get_deltas_since(42, 99); + let result = cache.get_deltas_since(99); let input = format!( "{}delta_window:\n{}", - get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 99), + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 99), indent_block(&deltas_window_to_string(&deltas), 2), ); let output = serial_result_detail_to_string(&result); test_report( "get_deltas_since_returns_reset_required_when_client_serial_is_too_old", - "验证当客户端 serial 太旧,已超出 delta window 覆盖范围时,返回 ResetRequired。", + "test purpose", &input, &output, ); @@ -382,8 +454,10 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() { } } +/// serial delta window ? +/// get_deltas_since() ?delta ? 
#[test] -fn get_deltas_since_returns_applicable_deltas_in_order() { +fn get_deltas_since_returns_minimal_merged_delta() { let d1 = Arc::new(Delta::new( 101, vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], @@ -400,60 +474,81 @@ fn get_deltas_since_returns_applicable_deltas_in_order() { vec![], )); + let final_snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), + Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498)), + ]); + let mut deltas = VecDeque::new(); deltas.push_back(d1); deltas.push_back(d2); deltas.push_back(d3); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(103) .timing(Timing::default()) + .snapshot(final_snapshot) .deltas(deltas.clone()) .build(); - let result = cache.get_deltas_since(42, 101); + let result = cache.get_deltas_since(101); let input = format!( "{}delta_window:\n{}", - get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101), + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101), indent_block(&deltas_window_to_string(&deltas), 2), ); let output = serial_result_detail_to_string(&result); test_report( - "get_deltas_since_returns_applicable_deltas_in_order", - "验证当客户端 serial 在 delta window 内时,返回正确且有序的 deltas。", + "get_deltas_since_returns_minimal_merged_delta", + "test purpose", &input, &output, ); match result { - SerialResult::Deltas(result) => { - assert_eq!(result.len(), 2); - assert_eq!(result[0].serial(), 102); - assert_eq!(result[1].serial(), 103); + SerialResult::Delta(delta) => { + assert_eq!(delta.serial(), 103); + assert_eq!(delta.announced().len(), 2); + assert_eq!(delta.withdrawn().len(), 0); + + let a0 = as_v4_route_origin(&delta.announced()[0]); + let a1 = as_v4_route_origin(&delta.announced()[1]); + + assert_eq!( + a0.prefix().address.to_ipv4(), + 
Some(Ipv4Addr::new(203, 0, 113, 0)) + ); + assert_eq!( + a1.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(198, 51, 100, 0)) + ); } - _ => panic!("expected Deltas"), + _ => panic!("expected Delta"), } } +/// serial serial +/// ResetRequired? #[test] fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() { let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .build(); - let result = cache.get_deltas_since(42, 101); + let result = cache.get_deltas_since(101); - let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101); + let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101); let output = serial_result_detail_to_string(&result); test_report( "get_deltas_since_returns_reset_required_when_client_serial_is_in_future", - "验证当客户端 serial 比缓存当前 serial 还大时,返回 ResetRequired。", + "test purpose", &input, &output, ); @@ -464,6 +559,161 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() { } } +#[test] +fn get_deltas_since_supports_incremental_updates_across_serial_wraparound() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + + let d_max = Arc::new(Delta::new( + u32::MAX, + vec![Payload::RouteOrigin(a.clone())], + vec![], + )); + let d_zero = Arc::new(Delta::new( + 0, + vec![Payload::RouteOrigin(b.clone())], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d_max); + deltas.push_back(d_zero); + + let final_snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(a.clone()), + Payload::RouteOrigin(b.clone()), + ]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(0) + .timing(Timing::default()) + .snapshot(final_snapshot) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(u32::MAX.wrapping_sub(1)); 
+ + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), u32::MAX.wrapping_sub(1)), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_supports_incremental_updates_across_serial_wraparound", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::Delta(delta) => { + assert_eq!(delta.serial(), 0); + assert_eq!(delta.announced().len(), 2); + assert_eq!(delta.withdrawn().len(), 0); + + match &delta.announced()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &b), + _ => panic!("expected announced RouteOrigin"), + } + + match &delta.announced()[1] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &a), + _ => panic!("expected announced RouteOrigin"), + } + } + _ => panic!("expected Delta"), + } +} + +#[test] +fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_wraparound() { + let d_max = Arc::new(Delta::new( + u32::MAX, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + )); + let d_zero = Arc::new(Delta::new( + 0, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + )); + let d_one = Arc::new(Delta::new( + 1, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d_max); + deltas.push_back(d_zero); + deltas.push_back(d_one); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(1) + .timing(Timing::default()) + .snapshot(Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), + Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498)), + ])) + .deltas(deltas.clone()) + .build(); + + let client_serial = u32::MAX.wrapping_sub(2); + let result 
= cache.get_deltas_since(client_serial); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), client_serial), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_wraparound", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::ResetRequired => {} + _ => panic!("expected ResetRequired"), + } +} + +#[test] +fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future_across_wraparound() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(u32::MAX) + .timing(Timing::default()) + .build(); + + let result = cache.get_deltas_since(0); + + let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 0); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_reset_required_when_client_serial_is_in_future_across_wraparound", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::ResetRequired => {} + _ => panic!("expected ResetRequired"), + } +} + +/// update() ?payload ? +/// serial delta? 
#[tokio::test] async fn update_no_change_keeps_serial_and_produces_no_delta() { let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -476,7 +726,7 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { let snapshot = Snapshot::from_payloads(old_input.clone()); let mut cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .snapshot(snapshot.clone()) @@ -493,16 +743,16 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { cache.update(new_payloads.clone(), &store).unwrap(); let current_snapshot = cache.snapshot(); - let result = cache.get_deltas_since(42, 100); + let result = cache.get_deltas_since(100); let input = format!( - "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + "old_snapshot :\n{}new_payloads :\n{}", indent_block(&payloads_to_string(&old_input), 2), indent_block(&payloads_to_string(&new_payloads), 2), ); let output = format!( - "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), indent_block(&serial_result_detail_to_string(&result), 2), @@ -510,7 +760,7 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { test_report( "update_no_change_keeps_serial_and_produces_no_delta", - "验证 update() 在新旧内容完全相同时:serial 不变、snapshot 不变、不会产生新的 delta。", + "test purpose", &input, &output, ); @@ -524,6 +774,8 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { } } +/// update() payload ? +/// ?serial?announced ? 
#[tokio::test] async fn update_add_only_increments_serial_and_generates_announced_delta() { let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -533,7 +785,7 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { let old_snapshot = Snapshot::from_payloads(old_input.clone()); let mut cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .snapshot(old_snapshot.clone()) @@ -550,16 +802,16 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { cache.update(new_payloads.clone(), &store).unwrap(); let current_snapshot = cache.snapshot(); - let result = cache.get_deltas_since(42, 100); + let result = cache.get_deltas_since(100); let input = format!( - "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + "old_snapshot :\n{}new_payloads :\n{}", indent_block(&payloads_to_string(&old_input), 2), indent_block(&payloads_to_string(&new_payloads), 2), ); let output = format!( - "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), indent_block(&serial_result_detail_to_string(&result), 2), @@ -567,7 +819,7 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { test_report( "update_add_only_increments_serial_and_generates_announced_delta", - "验证 update() 在只新增 payload 时:serial + 1,delta 中只有 announced,withdrawn 为空。", + "test purpose", &input, &output, ); @@ -575,10 +827,7 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { assert_eq!(cache.serial(), 101); match result { - SerialResult::Deltas(deltas) => { - assert_eq!(deltas.len(), 1); - let delta = &deltas[0]; - + SerialResult::Delta(delta) => { assert_eq!(delta.serial(), 101); assert_eq!(delta.announced().len(), 1); 
assert_eq!(delta.withdrawn().len(), 0); @@ -588,10 +837,12 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { _ => panic!("expected announced RouteOrigin"), } } - _ => panic!("expected Deltas"), + _ => panic!("expected Delta"), } } +/// update() payload ? +/// ?serial?withdrawn ? #[tokio::test] async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -604,7 +855,7 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { let old_snapshot = Snapshot::from_payloads(old_input.clone()); let mut cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .snapshot(old_snapshot.clone()) @@ -618,16 +869,16 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { cache.update(new_payloads.clone(), &store).unwrap(); let current_snapshot = cache.snapshot(); - let result = cache.get_deltas_since(42, 100); + let result = cache.get_deltas_since(100); let input = format!( - "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + "old_snapshot :\n{}new_payloads :\n{}", indent_block(&payloads_to_string(&old_input), 2), indent_block(&payloads_to_string(&new_payloads), 2), ); let output = format!( - "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), indent_block(&serial_result_detail_to_string(&result), 2), @@ -635,7 +886,7 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { test_report( "update_remove_only_increments_serial_and_generates_withdrawn_delta", - "验证 update() 在只删除 payload 时:serial + 1,delta 中只有 withdrawn,announced 为空。", + "test purpose", &input, &output, ); @@ -643,10 +894,7 @@ async fn 
update_remove_only_increments_serial_and_generates_withdrawn_delta() { assert_eq!(cache.serial(), 101); match result { - SerialResult::Deltas(deltas) => { - assert_eq!(deltas.len(), 1); - let delta = &deltas[0]; - + SerialResult::Delta(delta) => { assert_eq!(delta.serial(), 101); assert_eq!(delta.announced().len(), 0); assert_eq!(delta.withdrawn().len(), 1); @@ -656,10 +904,12 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { _ => panic!("expected withdrawn RouteOrigin"), } } - _ => panic!("expected Deltas"), + _ => panic!("expected Delta"), } } +/// update() payload +/// ?serial announced ?withdrawn? #[tokio::test] async fn update_add_and_remove_increments_serial_and_generates_both_sides() { let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); @@ -678,7 +928,7 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() { let old_snapshot = Snapshot::from_payloads(old_input.clone()); let mut cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::default()) .snapshot(old_snapshot.clone()) @@ -695,16 +945,16 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() { cache.update(new_payloads.clone(), &store).unwrap(); let current_snapshot = cache.snapshot(); - let result = cache.get_deltas_since(42, 100); + let result = cache.get_deltas_since(100); let input = format!( - "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + "old_snapshot :\n{}new_payloads :\n{}", indent_block(&payloads_to_string(&old_input), 2), indent_block(&payloads_to_string(&new_payloads), 2), ); let output = format!( - "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), indent_block(&serial_result_detail_to_string(&result), 2), @@ -712,7 +962,7 @@ 
async fn update_add_and_remove_increments_serial_and_generates_both_sides() { test_report( "update_add_and_remove_increments_serial_and_generates_both_sides", - "验证 update() 在同时新增和删除 payload 时:serial + 1,delta 中 announced 和 withdrawn 都正确。", + "test purpose", &input, &output, ); @@ -720,10 +970,7 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() { assert_eq!(cache.serial(), 101); match result { - SerialResult::Deltas(deltas) => { - assert_eq!(deltas.len(), 1); - let delta = &deltas[0]; - + SerialResult::Delta(delta) => { assert_eq!(delta.serial(), 101); assert_eq!(delta.announced().len(), 1); assert_eq!(delta.withdrawn().len(), 1); @@ -738,6 +985,374 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() { _ => panic!("expected withdrawn RouteOrigin"), } } - _ => panic!("expected Deltas"), + _ => panic!("expected Delta"), } -} \ No newline at end of file +} + +/// ?prefix announce withdraw +/// ?UpToDate? +#[test] +fn get_deltas_since_cancels_announce_then_withdraw_for_same_prefix() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + + let d1 = Arc::new(Delta::new( + 101, + vec![Payload::RouteOrigin(a.clone())], + vec![], + )); + let d2 = Arc::new(Delta::new( + 102, + vec![], + vec![Payload::RouteOrigin(a.clone())], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + + // A ? 
+ let final_snapshot = Snapshot::empty(); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(102) + .timing(Timing::default()) + .snapshot(final_snapshot) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(100); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_cancels_announce_then_withdraw_for_same_prefix", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::UpToDate => {} + _ => panic!("expected UpToDate"), + } +} + +/// ?prefix withdraw announce ? +/// ?UpToDate? +#[test] +fn get_deltas_since_cancels_withdraw_then_announce_for_same_prefix() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + + let d1 = Arc::new(Delta::new( + 101, + vec![], + vec![Payload::RouteOrigin(a.clone())], + )); + let d2 = Arc::new(Delta::new( + 102, + vec![Payload::RouteOrigin(a.clone())], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + + // ?A? 
+ let final_snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(a.clone())]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(102) + .timing(Timing::default()) + .snapshot(final_snapshot) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(100); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_cancels_withdraw_then_announce_for_same_prefix", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::UpToDate => {} + _ => panic!("expected UpToDate"), + } +} + +/// ?A ?B +/// ?delta?withdraw A + announce B? +#[test] +fn get_deltas_since_merges_replacement_into_withdraw_and_announce() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let b = v4_origin(192, 0, 2, 0, 24, 25, 64496); + + let d1 = Arc::new(Delta::new( + 101, + vec![], + vec![Payload::RouteOrigin(a.clone())], + )); + let d2 = Arc::new(Delta::new( + 102, + vec![Payload::RouteOrigin(b.clone())], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + + let final_snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(b.clone())]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(102) + .timing(Timing::default()) + .snapshot(final_snapshot) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(100); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_merges_replacement_into_withdraw_and_announce", 
+ "test purpose", + &input, + &output, + ); + + match result { + SerialResult::Delta(delta) => { + assert_eq!(delta.serial(), 102); + assert_eq!(delta.announced().len(), 1); + assert_eq!(delta.withdrawn().len(), 1); + + match &delta.withdrawn()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &a), + _ => panic!("expected withdrawn RouteOrigin"), + } + + match &delta.announced()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &b), + _ => panic!("expected announced RouteOrigin"), + } + } + _ => panic!("expected Delta"), + } +} + +/// delta ? +/// ? +#[test] +fn get_deltas_since_merges_multiple_deltas_to_final_minimal_view() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let c = v4_origin(203, 0, 113, 0, 24, 24, 64498); + + // 100 -> 101 : +A + let d1 = Arc::new(Delta::new( + 101, + vec![Payload::RouteOrigin(a.clone())], + vec![], + )); + // 101 -> 102 : -A +B + let d2 = Arc::new(Delta::new( + 102, + vec![Payload::RouteOrigin(b.clone())], + vec![Payload::RouteOrigin(a.clone())], + )); + // 102 -> 103 : -B +C + let d3 = Arc::new(Delta::new( + 103, + vec![Payload::RouteOrigin(c.clone())], + vec![Payload::RouteOrigin(b.clone())], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + deltas.push_back(d3); + + let final_snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(c.clone())]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(103) + .timing(Timing::default()) + .snapshot(final_snapshot) + .deltas(deltas.clone()) + .build(); + + // ?serial=100 A/B ?+C + let result = cache.get_deltas_since(100); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + 
"get_deltas_since_merges_multiple_deltas_to_final_minimal_view", + "test purpose", + &input, + &output, + ); + + match result { + SerialResult::Delta(delta) => { + assert_eq!(delta.serial(), 103); + assert_eq!(delta.announced().len(), 1); + assert_eq!(delta.withdrawn().len(), 0); + + match &delta.announced()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &c), + _ => panic!("expected announced RouteOrigin"), + } + } + _ => panic!("expected Delta"), + } +} + +#[test] +fn snapshot_from_payloads_unions_aspas_by_customer() { + let first = Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![Asn::from(64497u32)], + )); + let second = Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![Asn::from(64498u32), Asn::from(64497u32)], + )); + + let snapshot = Snapshot::from_payloads(vec![first, second]); + let aspas = snapshot.aspas().iter().collect::>(); + + assert_eq!(aspas.len(), 1); + assert_eq!(aspas[0].customer_asn(), Asn::from(64496u32)); + assert_eq!( + aspas[0].provider_asns(), + &[Asn::from(64497u32), Asn::from(64498u32)] + ); +} + +#[test] +fn snapshot_diff_replaces_aspa_with_single_announcement() { + let old_snapshot = Snapshot::from_payloads(vec![Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![Asn::from(64497u32)], + ))]); + let new_snapshot = Snapshot::from_payloads(vec![Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![Asn::from(64498u32)], + ))]); + + let (announced, withdrawn) = old_snapshot.diff(&new_snapshot); + + assert_eq!(announced.len(), 1); + assert!(withdrawn.is_empty()); + + match &announced[0] { + Payload::Aspa(aspa) => { + assert_eq!(aspa.customer_asn(), Asn::from(64496u32)); + assert_eq!(aspa.provider_asns(), &[Asn::from(64498u32)]); + } + _ => panic!("expected announced ASPA"), + } +} + +#[test] +fn get_deltas_since_merges_aspa_replacement_into_single_announcement() { + let old = Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32)]); + let new = Aspa::new(Asn::from(64496u32), vec![Asn::from(64498u32)]); + + let d1 = 
Arc::new(Delta::new(101, vec![], vec![Payload::Aspa(old.clone())])); + let d2 = Arc::new(Delta::new(102, vec![Payload::Aspa(new.clone())], vec![])); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(102) + .timing(Timing::default()) + .snapshot(Snapshot::from_payloads(vec![Payload::Aspa(new.clone())])) + .deltas(deltas) + .build(); + + let result = cache.get_deltas_since(100); + + match result { + SerialResult::Delta(delta) => { + assert_eq!(delta.announced().len(), 1); + assert!(delta.withdrawn().is_empty()); + + match &delta.announced()[0] { + Payload::Aspa(aspa) => assert_eq!(aspa, &new), + _ => panic!("expected announced ASPA"), + } + } + _ => panic!("expected Delta"), + } +} + +#[test] +fn validate_payloads_for_rtr_rejects_unsorted_snapshot_payloads() { + let low = Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)); + let high = Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)); + + let err = validate_payloads_for_rtr(&[low, high], true).unwrap_err(); + assert!(err + .to_string() + .contains("RTR payload ordering violation")); +} + +#[test] +fn validate_payload_updates_for_rtr_rejects_unsorted_aspa_updates() { + let withdraw = ( + false, + Payload::Aspa(Aspa::new( + Asn::from(64497u32), + vec![Asn::from(64500u32)], + )), + ); + let announce = ( + true, + Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![Asn::from(64499u32)], + )), + ); + + let err = validate_payload_updates_for_rtr(&[withdraw, announce]).unwrap_err(); + assert!(err.to_string().contains("withdraw ASPA")); + assert!(err.to_string().contains("announce ASPA")); +} + + diff --git a/tests/test_pdu.rs b/tests/test_pdu.rs new file mode 100644 index 0000000..06c2f87 --- /dev/null +++ b/tests/test_pdu.rs @@ -0,0 +1,204 @@ +use std::net::Ipv4Addr; + +use tokio::io::{duplex, AsyncWriteExt}; + +use 
rpki::data_model::resources::as_resources::Asn; +use rpki::rtr::error_type::ErrorCode; +use rpki::rtr::payload::{Aspa as PayloadAspa, Ski, Timing}; +use rpki::rtr::pdu::{ + Aspa, EndOfDataV1, ErrorReport, Flags, Header, IPv4Prefix, RouterKey, SerialNotify, + END_OF_DATA_V1_LEN, MAX_PDU_LEN, +}; + +const ERROR_REPORT_FIXED_PART_LEN: usize = 16; + +#[tokio::test] +async fn serial_notify_roundtrip() { + let (mut client, mut server) = duplex(1024); + let original = SerialNotify::new(1, 42, 100); + + tokio::spawn(async move { + original.write(&mut client).await.unwrap(); + }); + + let decoded = SerialNotify::read(&mut server).await.unwrap(); + + assert_eq!(decoded.version(), 1); + assert_eq!(decoded.session_id(), 42); + assert_eq!(decoded.serial_number(), 100); +} + +#[tokio::test] +async fn ipv4_prefix_roundtrip() { + let (mut client, mut server) = duplex(1024); + let prefix = IPv4Prefix::new( + 1, + Flags::new(1), + 24, + 24, + Ipv4Addr::new(192, 168, 0, 0), + 65000u32.into(), + ); + + tokio::spawn(async move { + prefix.write(&mut client).await.unwrap(); + }); + + let decoded = IPv4Prefix::read(&mut server).await.unwrap(); + + assert_eq!(decoded.prefix_len(), 24); + assert_eq!(decoded.max_len(), 24); + assert_eq!(decoded.prefix(), Ipv4Addr::new(192, 168, 0, 0)); + assert!(decoded.flag().is_announce()); +} + +#[test] +fn error_report_truncates_large_erroneous_pdu() { + let pdu = vec![0xAA; MAX_PDU_LEN as usize]; + let text = b"details"; + + let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), &pdu, text); + + assert_eq!(report.as_ref().len(), MAX_PDU_LEN as usize); + assert_eq!( + report.erroneous_pdu(), + &pdu[..(MAX_PDU_LEN as usize - ERROR_REPORT_FIXED_PART_LEN)] + ); + assert!(report.text().is_empty()); +} + +#[test] +fn error_report_truncates_text_to_fit() { + let pdu = [1, 2, 3, 4]; + let text = vec![b'x'; MAX_PDU_LEN as usize]; + + let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), pdu, &text); + + assert_eq!(report.erroneous_pdu(), 
pdu); + assert_eq!(report.as_ref().len(), MAX_PDU_LEN as usize); + assert_eq!( + report.text().len(), + MAX_PDU_LEN as usize - ERROR_REPORT_FIXED_PART_LEN - pdu.len() + ); +} + +#[tokio::test] +async fn error_report_rejects_non_utf8_text() { + let (mut client, mut server) = duplex(1024); + let header = Header::new(1, ErrorReport::PDU, ErrorCode::CorruptData.as_u16(), 17); + let mut bytes = Vec::from(header.as_ref()); + bytes.extend_from_slice(&0u32.to_be_bytes()); + bytes.extend_from_slice(&1u32.to_be_bytes()); + bytes.push(0xFF); + + client.write_all(&bytes).await.unwrap(); + + let err = ErrorReport::read(&mut server).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); +} + +#[tokio::test] +async fn router_key_length_matches_wire_size() { + let ski = Ski::default(); + let spki: std::sync::Arc<[u8]> = std::sync::Arc::from(vec![1u8; 32]); + let pdu = RouterKey::new(1, Flags::new(1), ski, Asn::from(64496u32), spki); + let (mut client, mut server) = duplex(1024); + + tokio::spawn(async move { + pdu.write(&mut client).await.unwrap(); + }); + + let header = Header::read(&mut server).await.unwrap(); + assert_eq!(header.pdu(), RouterKey::PDU); + assert_eq!(header.length(), 8 + 20 + 4 + 32); +} + +#[tokio::test] +async fn aspa_length_matches_wire_size() { + let pdu = Aspa::new(2, Flags::new(1), 64496, vec![64497, 64498]); + let (mut client, mut server) = duplex(1024); + + tokio::spawn(async move { + pdu.write(&mut client).await.unwrap(); + }); + + let header = Header::read(&mut server).await.unwrap(); + assert_eq!(header.pdu(), Aspa::PDU); + assert_eq!(header.length(), 8 + 4 + 8); +} + +#[test] +fn aspa_announcement_rejects_empty_provider_list() { + let err = PayloadAspa::new(Asn::from(64496u32), vec![]) + .validate_announcement() + .unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("at least one provider")); +} + +#[test] +fn aspa_announcement_rejects_as0() { + let err = 
PayloadAspa::new(Asn::from(0u32), vec![Asn::from(64497u32)]) + .validate_announcement() + .unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("AS0")); + + let err = PayloadAspa::new(Asn::from(64496u32), vec![Asn::from(0u32)]) + .validate_announcement() + .unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("AS0")); +} + +#[test] +fn timing_rejects_out_of_range_refresh() { + let err = Timing::new(0, 600, 7200).validate().unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("refresh interval")); +} + +#[test] +fn timing_rejects_expire_not_greater_than_retry_and_refresh() { + let err = Timing::new(600, 600, 600).validate().unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("expire interval")); +} + +#[test] +fn end_of_data_v1_rejects_invalid_timing() { + let err = EndOfDataV1::new(1, 42, 100, Timing::new(600, 8000, 7200)).unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("retry interval")); +} + +#[tokio::test] +async fn end_of_data_v1_read_rejects_invalid_timing() { + let (mut client, mut server) = duplex(1024); + let header = Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN); + let mut bytes = Vec::from(header.as_ref()); + bytes.extend_from_slice(&100u32.to_be_bytes()); + bytes.extend_from_slice(&600u32.to_be_bytes()); + bytes.extend_from_slice(&8000u32.to_be_bytes()); + bytes.extend_from_slice(&7200u32.to_be_bytes()); + + client.write_all(&bytes).await.unwrap(); + + let err = EndOfDataV1::read(&mut server).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("retry interval")); +} + +#[tokio::test] +async fn end_of_data_v1_read_payload_rejects_invalid_timing() { + let (mut client, mut server) = duplex(1024); + let header 
= Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN); + client.write_all(&100u32.to_be_bytes()).await.unwrap(); + client.write_all(&600u32.to_be_bytes()).await.unwrap(); + client.write_all(&600u32.to_be_bytes()).await.unwrap(); + client.write_all(&600u32.to_be_bytes()).await.unwrap(); + + let err = EndOfDataV1::read_payload(header, &mut server).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("expire interval")); +} diff --git a/tests/test_session.rs b/tests/test_session.rs index e7ed8e8..fd2d719 100644 --- a/tests/test_session.rs +++ b/tests/test_session.rs @@ -1,11 +1,21 @@ -mod common; +mod common; use std::collections::VecDeque; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::fs::File; +use std::io::BufReader; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; -use tokio::net::TcpListener; +use rustls::{ClientConfig, RootCertStore}; +use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{broadcast, watch}; +use tokio::task::JoinHandle; +use tokio::time::{timeout, Duration}; +use tokio_rustls::{TlsAcceptor, TlsConnector}; use common::test_helper::{ dump_cache_reset, dump_cache_response, dump_eod_v1, dump_ipv4_prefix, dump_ipv6_prefix, @@ -13,17 +23,257 @@ use common::test_helper::{ }; use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; -use rpki::rtr::cache::{Delta, SharedRtrCache, RtrCacheBuilder, Snapshot}; -use rpki::rtr::payload::{Payload, RouteOrigin, Timing}; +use rpki::data_model::resources::as_resources::Asn; +use rpki::rtr::cache::{Delta, RtrCacheBuilder, SessionIds, SharedRtrCache, Snapshot}; +use rpki::rtr::error_type::ErrorCode; +use rpki::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski, Timing}; use rpki::rtr::pdu::{ - CacheResponse, 
CacheReset, EndOfDataV1, IPv4Prefix, IPv6Prefix, ResetQuery, SerialQuery, + Aspa as AspaPdu, CacheReset, CacheResponse, EndOfDataV1, ErrorReport, Header, IPv4Prefix, + IPv6Prefix, ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery, }; +use rpki::rtr::server::connection::handle_tls_connection; +use rpki::rtr::server::tls::load_rustls_server_config_with_options; use rpki::rtr::session::RtrSession; fn shared_cache(cache: rpki::rtr::cache::RtrCache) -> SharedRtrCache { Arc::new(RwLock::new(cache)) } +async fn start_session_server( + cache: SharedRtrCache, +) -> ( + SocketAddr, + broadcast::Sender<()>, + watch::Sender<bool>, + JoinHandle<()>, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let (notify_tx, notify_rx) = broadcast::channel(16); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let handle = tokio::spawn(async move { + let Ok((stream, _)) = listener.accept().await else { + return; + }; + + let session = RtrSession::new(cache, stream, notify_rx, shutdown_rx); + let _ = session.run().await; + }); + + (addr, notify_tx, shutdown_tx, handle) +} + +async fn start_session_server_with_transport_timeout( + cache: SharedRtrCache, + transport_timeout: Duration, +) -> ( + SocketAddr, + broadcast::Sender<()>, + watch::Sender<bool>, + JoinHandle<()>, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let (notify_tx, notify_rx) = broadcast::channel(16); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let handle = tokio::spawn(async move { + let Ok((stream, _)) = listener.accept().await else { + return; + }; + + let session = RtrSession::new(cache, stream, notify_rx, shutdown_rx) + .with_transport_timeout(transport_timeout); + let _ = session.run().await; + }); + + (addr, notify_tx, shutdown_tx, handle) +} + +async fn start_session_server_returning_result( + cache: SharedRtrCache, +) -> ( + SocketAddr, + 
watch::Sender, + JoinHandle>, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + + let handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(cache, stream, notify_rx, shutdown_rx); + session.run().await + }); + + (addr, shutdown_tx, handle) +} + +async fn start_tls_session_server( + cache: SharedRtrCache, +) -> ( + SocketAddr, + watch::Sender, + JoinHandle<()>, +) { + start_tls_session_server_with_cert(cache, "server.crt", "server.key").await +} + +async fn start_tls_session_server_with_cert( + cache: SharedRtrCache, + cert_name: &str, + key_name: &str, +) -> ( + SocketAddr, + watch::Sender, + JoinHandle<()>, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let tls_config = Arc::new( + load_rustls_server_config_with_options( + fixture_path(cert_name), + fixture_path(key_name), + fixture_path("client-ca.crt"), + false, + ) + .unwrap(), + ); + let acceptor = TlsAcceptor::from(tls_config); + + let handle = tokio::spawn(async move { + let Ok((stream, peer_addr)) = listener.accept().await else { + return; + }; + + let _ = handle_tls_connection(cache, stream, peer_addr, acceptor, notify_rx, shutdown_rx).await; + }); + + (addr, shutdown_tx, handle) +} + +async fn shutdown_server( + mut client: TcpStream, + shutdown_tx: watch::Sender, + server_handle: JoinHandle<()>, +) { + shutdown_io(&mut client, shutdown_tx, server_handle).await; +} + +async fn shutdown_io( + io: &mut S, + shutdown_tx: watch::Sender, + server_handle: JoinHandle<()>, +) where + S: AsyncWrite + Unpin, +{ + let _ = io.shutdown().await; + + let _ = shutdown_tx.send(true); + + 
match timeout(Duration::from_secs(1), server_handle).await { + Ok(join_res) => { + let _ = join_res; + } + Err(_) => { + panic!("server task did not exit within timeout"); + } + } +} + +fn fixture_path(name: &str) -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join("tls") + .join(name) +} + +fn load_pem_certs(path: &Path) -> Vec> { + let file = File::open(path).unwrap(); + let mut reader = BufReader::new(file); + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .unwrap() +} + +fn load_pem_key(path: &Path) -> PrivateKeyDer<'static> { + let file = File::open(path).unwrap(); + let mut reader = BufReader::new(file); + rustls_pemfile::private_key(&mut reader) + .unwrap() + .expect("missing PEM private key") +} + +async fn connect_tls_client( + addr: SocketAddr, + cert_name: &str, + key_name: &str, +) -> tokio_rustls::client::TlsStream { + connect_tls_client_with_server_name( + addr, + cert_name, + key_name, + ServerName::IpAddress(addr.ip().into()), + ) + .await +} + +async fn connect_tls_client_with_server_name( + addr: SocketAddr, + cert_name: &str, + key_name: &str, + server_name: ServerName<'static>, +) -> tokio_rustls::client::TlsStream { + let mut roots = RootCertStore::empty(); + for cert in load_pem_certs(&fixture_path("client-ca.crt")) { + roots.add(cert).unwrap(); + } + + let certs = load_pem_certs(&fixture_path(cert_name)); + let key = load_pem_key(&fixture_path(key_name)); + let client_config = ClientConfig::builder() + .with_root_certificates(roots) + .with_client_auth_cert(certs, key) + .unwrap(); + + let connector = TlsConnector::from(Arc::new(client_config)); + let tcp = TcpStream::connect(addr).await.unwrap(); + connector.connect(server_name, tcp).await.unwrap() +} + +/// 用于 dump Serial Notify,保持输出风格一致。 +fn dump_serial_notify(notify: &SerialNotify) -> serde_json::Value { + json!({ + "version": notify.version(), + "pdu": notify.pdu(), + "pdu_name": "Serial Notify", + "session_id": notify.session_id(), 
+ "serial_number": notify.serial_number(), + }) +} + +fn assert_error_report_matches( + report: &ErrorReport, + version: u8, + code: ErrorCode, + offending_pdu: &[u8], +) { + assert_eq!(report.version(), version); + assert_eq!(report.error_code(), Ok(code)); + assert_eq!(report.erroneous_pdu(), offending_pdu); +} + +/// 测试:Reset Query 会返回完整 snapshot,并以 End of Data 结束响应。 #[tokio::test] async fn reset_query_returns_snapshot_and_end_of_data() { let prefix = IPAddressPrefix { @@ -34,26 +284,16 @@ async fn reset_query_returns_snapshot_and_end_of_data() { let snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(origin)]); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::new(600, 600, 7200)) .snapshot(snapshot) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); ResetQuery::new(1).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); @@ -86,12 +326,45 @@ async fn reset_query_returns_snapshot_and_end_of_data() { assert_eq!(timing.expire, 7200); dump.print_pretty("reset_query_returns_snapshot_and_end_of_data"); + shutdown_server(client, shutdown_tx, server_handle).await; } +#[tokio::test] +async fn reset_query_uses_version_specific_session_id() { + let cache = RtrCacheBuilder::new() + 
.session_ids(SessionIds::from_array([40, 41, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(2).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 2); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.version(), 2); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + dump.print_pretty("reset_query_uses_version_specific_session_id"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:当 Serial Query 的 session_id 和 serial 都与当前 cache 一致时,仅返回 End of Data。 #[tokio::test] async fn serial_query_returns_end_of_data_when_up_to_date() { let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing { refresh: 600, @@ -100,20 +373,10 @@ async fn serial_query_returns_end_of_data_when_up_to_date() { }) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = 
tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); @@ -131,12 +394,14 @@ async fn serial_query_returns_end_of_data_when_up_to_date() { assert_eq!(timing.expire, 7200); dump.print_pretty("serial_query_returns_end_of_data_when_up_to_date"); + shutdown_server(client, shutdown_tx, server_handle).await; } +/// 测试:当已建立 session 后收到错误的 session_id 时,返回 CorruptData 并关闭连接。 #[tokio::test] -async fn serial_query_returns_cache_reset_when_session_id_mismatch() { +async fn serial_query_returns_corrupt_data_when_session_id_mismatch() { let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing { refresh: 600, @@ -145,32 +410,38 @@ async fn serial_query_returns_cache_reset_when_session_id_mismatch() { }) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); SerialQuery::new(1, 999, 100).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); - let reset = CacheReset::read(&mut client).await.unwrap(); - dump.push_value(reset.pdu(), dump_cache_reset(reset.version(), reset.pdu())); - assert_eq!(reset.pdu(), 8); - assert_eq!(reset.version(), 1); + let report = ErrorReport::read(&mut 
client).await.unwrap(); + assert_error_report_matches( + &report, + 1, + ErrorCode::CorruptData, + SerialQuery::new(1, 999, 100).as_ref(), + ); - dump.print_pretty("serial_query_returns_cache_reset_when_session_id_mismatch"); + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_corrupt_session_id", + "result": "header_read_failed_as_expected" + }), + ); + + dump.print_pretty("serial_query_returns_corrupt_data_when_session_id_mismatch"); + shutdown_server(client, shutdown_tx, server_handle).await; } +/// 测试:当增量更新可用时,Serial Query 返回 Cache Response + delta payload + End of Data。 #[tokio::test] async fn serial_query_returns_deltas_when_incremental_update_available() { let prefix = IPAddressPrefix { @@ -179,17 +450,13 @@ async fn serial_query_returns_deltas_when_incremental_update_available() { }; let origin = RouteOrigin::new(prefix, 24, 64496u32.into()); - let delta = Arc::new(Delta::new( - 101, - vec![Payload::RouteOrigin(origin)], - vec![], - )); + let delta = Arc::new(Delta::new(101, vec![Payload::RouteOrigin(origin)], vec![])); let mut deltas = VecDeque::new(); deltas.push_back(delta); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(101) .timing(Timing { refresh: 600, @@ -199,20 +466,10 @@ async fn serial_query_returns_deltas_when_incremental_update_available() { .deltas(deltas) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, 
shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); @@ -246,8 +503,112 @@ async fn serial_query_returns_deltas_when_incremental_update_available() { assert_eq!(timing.expire, 7200); dump.print_pretty("serial_query_returns_deltas_when_incremental_update_available"); + shutdown_server(client, shutdown_tx, server_handle).await; } +#[tokio::test] +async fn serial_query_returns_deltas_across_serial_wraparound() { + let first_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)), + prefix_length: 24, + }; + let first_origin = RouteOrigin::new(first_prefix, 24, 64496u32.into()); + + let second_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(198, 51, 100, 0)), + prefix_length: 24, + }; + let second_origin = RouteOrigin::new(second_prefix, 24, 64497u32.into()); + + let d_max = Arc::new(Delta::new( + u32::MAX, + vec![Payload::RouteOrigin(first_origin.clone())], + vec![], + )); + let d_zero = Arc::new(Delta::new( + 0, + vec![Payload::RouteOrigin(second_origin.clone())], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d_max); + deltas.push_back(d_zero); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(first_origin), + Payload::RouteOrigin(second_origin), + ]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(0) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .deltas(deltas) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, u32::MAX.wrapping_sub(1)) + .write(&mut client) + .await + 
.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let first = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(first.pdu(), dump_ipv4_prefix(&first)); + assert!(first.flag().is_announce()); + assert_eq!(first.prefix(), Ipv4Addr::new(198, 51, 100, 0)); + + let second = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(second.pdu(), dump_ipv4_prefix(&second)); + assert!(second.flag().is_announce()); + assert_eq!(second.prefix(), Ipv4Addr::new(192, 0, 2, 0)); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 0); + + dump.print_pretty("serial_query_returns_deltas_across_serial_wraparound"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn serial_query_returns_cache_reset_for_future_serial_across_wraparound() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(u32::MAX) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, 0).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let reset = CacheReset::read(&mut client).await.unwrap(); + dump.push_value(reset.pdu(), dump_cache_reset(reset.version(), reset.pdu())); + assert_eq!(reset.pdu(), CacheReset::PDU); + assert_eq!(reset.version(), 1); + + dump.print_pretty("serial_query_returns_cache_reset_for_future_serial_across_wraparound"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 
测试:Reset Query 返回的 payload 顺序符合当前实现的 RTR 排序规则。 #[tokio::test] async fn reset_query_returns_payloads_in_rtr_order() { let v4_low_prefix = IPAddressPrefix { @@ -275,26 +636,16 @@ async fn reset_query_returns_payloads_in_rtr_order() { ]); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(100) .timing(Timing::new(600, 600, 7200)) .snapshot(snapshot) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); ResetQuery::new(1).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); @@ -350,8 +701,10 @@ async fn reset_query_returns_payloads_in_rtr_order() { assert_eq!(timing.expire, 7200); dump.print_pretty("reset_query_returns_payloads_in_rtr_order"); + shutdown_server(client, shutdown_tx, server_handle).await; } +/// 测试:Serial Query 返回的增量中,announcement 在前,withdrawal 在后,且各自内部顺序符合当前实现。 #[tokio::test] async fn serial_query_returns_announcements_before_withdrawals() { let announced_low_prefix = IPAddressPrefix { @@ -394,7 +747,7 @@ async fn serial_query_returns_announcements_before_withdrawals() { deltas.push_back(delta); let cache = RtrCacheBuilder::new() - .session_id(42) + .session_ids(SessionIds::from_array([42, 42, 42])) .serial(101) .timing(Timing { refresh: 600, @@ -404,20 +757,10 @@ async fn serial_query_returns_announcements_before_withdrawals() { 
.deltas(deltas) .build(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server_cache = shared_cache(cache); - let (_notify_tx, notify_rx) = broadcast::channel(16); - let (_shutdown_tx, shutdown_rx) = watch::channel(false); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); - session.run().await.unwrap(); - }); - - let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + let mut client = TcpStream::connect(addr).await.unwrap(); SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); let mut dump = RtrDebugDumper::new(); @@ -489,4 +832,951 @@ async fn serial_query_returns_announcements_before_withdrawals() { assert_eq!(timing.expire, 7200); dump.print_pretty("serial_query_returns_announcements_before_withdrawals"); -} \ No newline at end of file + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:session 建立后,收到 notify 广播时会发送 Serial Notify。 +#[tokio::test] +async fn established_session_sends_serial_notify() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + 
dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + notify_tx.send(()).unwrap(); + + let notify = SerialNotify::read(&mut client).await.unwrap(); + dump.push_value(notify.pdu(), dump_serial_notify(¬ify)); + assert_eq!(notify.pdu(), SerialNotify::PDU); + assert_eq!(notify.version(), 1); + assert_eq!(notify.session_id(), 42); + assert_eq!(notify.serial_number(), 100); + + dump.print_pretty("established_session_sends_serial_notify"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:首个 PDU 版本过高时,返回 UnsupportedProtocolVersion 错误并关闭连接。 +#[tokio::test] +async fn first_pdu_with_too_high_version_returns_unsupported_version_error() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + + ResetQuery::new(3).write(&mut client).await.unwrap(); + + let report = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &report, + 2, + ErrorCode::UnsupportedProtocolVersion, + ResetQuery::new(3).as_ref(), + ); + + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_unsupported_protocol_version", + "result": "header_read_failed_as_expected" + }), + ); + + dump.print_pretty("first_pdu_with_too_high_version_returns_unsupported_version_error"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:版本协商完成后,如果后续请求更换了协议版本,返回 UnexpectedProtocolVersion 并关闭连接。 +#[tokio::test] +async fn session_rejects_version_change_after_negotiation() { + let cache = RtrCacheBuilder::new() + 
.session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + ResetQuery::new(2).write(&mut client).await.unwrap(); + + let report = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &report, + 1, + ErrorCode::UnexpectedProtocolVersion, + ResetQuery::new(2).as_ref(), + ); + + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_unexpected_protocol_version", + "result": "header_read_failed_as_expected" + }), + ); + + dump.print_pretty("session_rejects_version_change_after_negotiation"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:在版本协商完成前,即使收到 notify 广播,也不能发送 Serial Notify。 +#[tokio::test] +async fn notify_is_not_sent_before_version_negotiation() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + + let mut dump = 
RtrDebugDumper::new(); + + notify_tx.send(()).unwrap(); + + let res = timeout(Duration::from_millis(100), SerialNotify::read(&mut client)).await; + assert!( + res.is_err(), + "serial notify should not be sent before version negotiation" + ); + + dump.push_value( + 0, + json!({ + "event": "notify_before_version_negotiation", + "result": "no_serial_notify_received_within_timeout", + "timeout_ms": 100 + }), + ); + + dump.print_pretty("notify_is_not_sent_before_version_negotiation"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +/// 测试:同一 session 在一分钟窗口内连续收到 notify 广播时,只会发送一个 Serial Notify。 +#[tokio::test] +async fn serial_notify_is_rate_limited_to_once_per_minute() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.pdu(), 3); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + notify_tx.send(()).unwrap(); + + let first_notify = SerialNotify::read(&mut client).await.unwrap(); + dump.push_value(first_notify.pdu(), dump_serial_notify(&first_notify)); + assert_eq!(first_notify.pdu(), SerialNotify::PDU); + assert_eq!(first_notify.version(), 1); + assert_eq!(first_notify.session_id(), 42); + 
assert_eq!(first_notify.serial_number(), 100); + + notify_tx.send(()).unwrap(); + + let second_res = timeout(Duration::from_millis(100), SerialNotify::read(&mut client)).await; + assert!( + second_res.is_err(), + "second serial notify should be rate-limited within one minute" + ); + + dump.push_value( + 0, + json!({ + "event": "second_notify_within_rate_limit_window", + "result": "no_second_serial_notify_received_within_timeout", + "timeout_ms": 100 + }), + ); + + dump.print_pretty("serial_notify_is_rate_limited_to_once_per_minute"); + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn reset_query_returns_no_data_available_when_cache_is_unavailable() { + let cache = RtrCacheBuilder::new() + .availability(rpki::rtr::cache::CacheAvailability::NoDataAvailable) + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let report = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &report, + 1, + ErrorCode::NoDataAvailable, + ResetQuery::new(1).as_ref(), + ); + + ResetQuery::new(1).write(&mut client).await.unwrap(); + let second = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &second, + 1, + ErrorCode::NoDataAvailable, + ResetQuery::new(1).as_ref(), + ); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn serial_query_returns_no_data_available_when_cache_is_unavailable() { + let cache = RtrCacheBuilder::new() + .availability(rpki::rtr::cache::CacheAvailability::NoDataAvailable) + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let 
server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); + + let report = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &report, + 1, + ErrorCode::NoDataAvailable, + SerialQuery::new(1, 42, 100).as_ref(), + ); + + SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); + let second = ErrorReport::read(&mut client).await.unwrap(); + assert_error_report_matches( + &second, + 1, + ErrorCode::NoDataAvailable, + SerialQuery::new(1, 42, 100).as_ref(), + ); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn first_pdu_with_invalid_length_returns_corrupt_data() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, _shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + + let header = Header::new(1, ResetQuery::PDU, 0, 9); + let mut request = Vec::from(header.as_ref()); + request.push(0); + + timeout(Duration::from_secs(1), client.write_all(&request)) + .await + .expect("write_all timed out") + .unwrap(); + dump.push_value( + ResetQuery::PDU, + json!({ + "event": "invalid_first_pdu_sent", + "raw_hex": common::test_helper::bytes_to_hex(&request), + "length": request.len(), + }), + ); + + let report = timeout(Duration::from_secs(1), ErrorReport::read(&mut client)) + .await + .expect("timed out waiting for ErrorReport") + .unwrap(); + dump.push_value( + ErrorReport::PDU, + json!({ + "version": report.version(), + "pdu": ErrorReport::PDU, + "pdu_name": "Error Report", + "error_code": 
report.error_code().map(|code| code.as_u16()).unwrap_or_else(|code| code), + "erroneous_pdu_len": report.erroneous_pdu().len(), + "erroneous_pdu_hex": common::test_helper::bytes_to_hex(report.erroneous_pdu()), + "text": String::from_utf8_lossy(report.text()), + }), + ); + + assert_error_report_matches(&report, 1, ErrorCode::CorruptData, &request); + assert!(std::str::from_utf8(report.text()).unwrap().contains("invalid length")); + + let read_res = timeout(Duration::from_secs(1), Header::read(&mut client)) + .await + .expect("timed out waiting for connection close"); + assert!(read_res.is_err()); + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_invalid_first_pdu", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("first_pdu_with_invalid_length_returns_corrupt_data"); + + drop(client); + server_handle.abort(); + let _ = server_handle.await; +} + +#[tokio::test] +async fn established_session_closes_after_receiving_error_report() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + + let report = ErrorReport::new(1, ErrorCode::InternalError.as_u16(), [], b"peer error"); + client.write_all(report.as_ref()).await.unwrap(); + dump.push_value( + ErrorReport::PDU, + json!({ + "event": "peer_error_report_sent", + "version": report.version(), + "error_code": report.error_code().map(|code| 
code.as_u16()).unwrap_or_else(|code| code), + "text": String::from_utf8_lossy(report.text()), + }), + ); + + let read_res = timeout(Duration::from_secs(1), Header::read(&mut client)) + .await + .unwrap(); + assert!(read_res.is_err()); + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_receiving_peer_error_report", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("established_session_closes_after_receiving_error_report"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn established_session_invalid_header_returns_corrupt_data() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + + let invalid_header = Header::new(1, SerialQuery::PDU, 42, 7); + client.write_all(invalid_header.as_ref()).await.unwrap(); + dump.push_value( + SerialQuery::PDU, + json!({ + "event": "invalid_header_sent_after_establishment", + "raw_hex": common::test_helper::bytes_to_hex(invalid_header.as_ref()), + "length": invalid_header.length(), + }), + ); + + let report = ErrorReport::read(&mut client).await.unwrap(); + dump.push_value( + ErrorReport::PDU, + json!({ + "version": report.version(), + "pdu": ErrorReport::PDU, + "pdu_name": "Error Report", + "error_code": report.error_code().map(|code| code.as_u16()).unwrap_or_else(|code| code), + "erroneous_pdu_len": 
report.erroneous_pdu().len(), + "erroneous_pdu_hex": common::test_helper::bytes_to_hex(report.erroneous_pdu()), + "text": String::from_utf8_lossy(report.text()), + }), + ); + assert_error_report_matches(&report, 1, ErrorCode::CorruptData, invalid_header.as_ref()); + assert!(std::str::from_utf8(report.text()) + .unwrap() + .contains("invalid PDU length")); + + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_invalid_established_header", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("established_session_invalid_header_returns_corrupt_data"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn established_session_unknown_pdu_returns_unsupported_pdu_type() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + + let unknown_pdu = Header::new(1, 12, 0, 8); + client.write_all(unknown_pdu.as_ref()).await.unwrap(); + dump.push_value( + 12, + json!({ + "event": "unknown_pdu_sent_after_establishment", + "raw_hex": common::test_helper::bytes_to_hex(unknown_pdu.as_ref()), + "length": unknown_pdu.length(), + }), + ); + + let report = ErrorReport::read(&mut client).await.unwrap(); + dump.push_value( + ErrorReport::PDU, + json!({ + "version": report.version(), + "pdu": 
ErrorReport::PDU, + "pdu_name": "Error Report", + "error_code": report.error_code().map(|code| code.as_u16()).unwrap_or_else(|code| code), + "erroneous_pdu_len": report.erroneous_pdu().len(), + "erroneous_pdu_hex": common::test_helper::bytes_to_hex(report.erroneous_pdu()), + "text": String::from_utf8_lossy(report.text()), + }), + ); + + assert_error_report_matches(&report, 1, ErrorCode::UnsupportedPduType, unknown_pdu.as_ref()); + + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_unknown_pdu", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("established_session_unknown_pdu_returns_unsupported_pdu_type"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn version_zero_does_not_send_router_key_or_aspa() { + let router_key = RouterKey::new(Ski::default(), Asn::from(64496u32), vec![1u8; 32]); + let aspa = Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32), Asn::from(64498u32)]); + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouterKey(router_key), + Payload::Aspa(aspa), + ]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(0).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let eod = rpki::rtr::pdu::EndOfDataV0::read(&mut client).await.unwrap(); + dump.push_value( + eod.pdu(), + json!({ + "version": eod.version(), + "pdu": eod.pdu(), + "pdu_name": "End of Data", + "session_id": 
eod.session_id(), + "serial_number": eod.serial_number(), + }), + ); + + let res = timeout(Duration::from_millis(100), Header::read(&mut client)).await; + assert!(res.is_err(), "version 0 response should not contain RouterKey or ASPA PDUs"); + dump.print_pretty("version_zero_does_not_send_router_key_or_aspa"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn version_two_aspa_withdraw_has_empty_provider_list() { + let aspa = Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32), Asn::from(64498u32)]); + let delta = Arc::new(Delta::new(101, vec![], vec![Payload::Aspa(aspa)])); + let mut deltas = VecDeque::new(); + deltas.push_back(delta); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(101) + .timing(Timing::new(600, 600, 7200)) + .deltas(deltas) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + SerialQuery::new(2, 42, 100).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let header = Header::read(&mut client).await.unwrap(); + assert_eq!(header.pdu(), AspaPdu::PDU); + assert_eq!(header.length(), 12); + let mut body = [0u8; 4]; + client.read_exact(&mut body).await.unwrap(); + assert_eq!(u32::from_be_bytes(body), 64496); + dump.push_value( + AspaPdu::PDU, + json!({ + "version": header.version(), + "pdu": header.pdu(), + "length": header.length(), + "customer_asn": u32::from_be_bytes(body), + "withdraw": true, + }), + ); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + dump.print_pretty("version_two_aspa_withdraw_has_empty_provider_list"); + + shutdown_server(client, shutdown_tx, 
server_handle).await; +} + +#[tokio::test] +async fn version_one_sends_router_key_but_not_aspa() { + let router_key = RouterKey::new(Ski::default(), Asn::from(64496u32), vec![1u8; 32]); + let aspa = Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32)]); + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouterKey(router_key), + Payload::Aspa(aspa), + ]); + + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let header = Header::read(&mut client).await.unwrap(); + assert_eq!(header.pdu(), RouterKeyPdu::PDU); + dump.push_value( + RouterKeyPdu::PDU, + json!({ + "version": header.version(), + "pdu": header.pdu(), + "length": header.length(), + }), + ); + + let payload_len = header.length() as usize - 8; + let mut payload = vec![0u8; payload_len]; + client.read_exact(&mut payload).await.unwrap(); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + + let res = timeout(Duration::from_millis(100), Header::read(&mut client)).await; + assert!(res.is_err(), "version 1 response should not contain ASPA PDUs"); + dump.print_pretty("version_one_sends_router_key_but_not_aspa"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn established_session_idle_timeout_returns_transport_failed() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + 
.timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, _notify_tx, shutdown_tx, server_handle) = + start_session_server_with_transport_timeout(server_cache, Duration::from_millis(100)).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + + let report = timeout(Duration::from_secs(1), ErrorReport::read(&mut client)) + .await + .expect("timed out waiting for transport failure ErrorReport") + .unwrap(); + dump.push_value( + ErrorReport::PDU, + json!({ + "version": report.version(), + "pdu": ErrorReport::PDU, + "pdu_name": "Error Report", + "error_code": report.error_code().map(|code| code.as_u16()).unwrap_or_else(|code| code), + "erroneous_pdu_len": report.erroneous_pdu().len(), + "erroneous_pdu_hex": common::test_helper::bytes_to_hex(report.erroneous_pdu()), + "text": String::from_utf8_lossy(report.text()), + }), + ); + + assert_eq!(report.version(), 1); + assert_eq!(report.error_code(), Ok(ErrorCode::TransportFailed)); + assert!(report.erroneous_pdu().is_empty()); + assert!(std::str::from_utf8(report.text()).unwrap().contains("transport stalled")); + + let read_res = Header::read(&mut client).await; + assert!(read_res.is_err()); + dump.push_value( + 0, + json!({ + "event": "connection_closed_after_transport_timeout", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("established_session_idle_timeout_returns_transport_failed"); + + shutdown_server(client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn tls_client_with_matching_san_ip_is_accepted() { + let cache = RtrCacheBuilder::new() + 
.session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, shutdown_tx, server_handle) = start_tls_session_server(server_cache).await; + + let mut client = connect_tls_client(addr, "client-good.crt", "client-good.key").await; + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + dump.print_pretty("tls_client_with_matching_san_ip_is_accepted"); + shutdown_io(&mut client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn tls_client_accepts_server_certificate_with_dns_san() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, shutdown_tx, server_handle) = + start_tls_session_server_with_cert(server_cache, "server-dns.crt", "server-dns.key").await; + + let mut client = connect_tls_client_with_server_name( + addr, + "client-good.crt", + "client-good.key", + ServerName::try_from("localhost").unwrap(), + ) + .await; + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), 
dump_eod_v1(&eod)); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + dump.print_pretty("tls_client_accepts_server_certificate_with_dns_san"); + shutdown_io(&mut client, shutdown_tx, server_handle).await; +} + +#[tokio::test] +async fn tls_server_dns_name_san_strict_mode_rejects_ip_only_certificate() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let err = load_rustls_server_config_with_options( + fixture_path("server.crt"), + fixture_path("server.key"), + fixture_path("client-ca.crt"), + true, + ) + .unwrap_err(); + + assert!(err + .to_string() + .contains("does not contain a subjectAltName dNSName entry")); + + let _ = cache; +} + +#[tokio::test] +async fn tls_client_with_mismatched_san_ip_is_rejected() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, shutdown_tx, server_handle) = start_tls_session_server(server_cache).await; + + let mut client = connect_tls_client(addr, "client-bad.crt", "client-bad.key").await; + let mut dump = RtrDebugDumper::new(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + dump.push_value( + ResetQuery::PDU, + json!({ + "event": "tls_reset_query_sent_with_bad_client_cert", + "version": 1, + }), + ); + + let read_res = timeout(Duration::from_secs(1), Header::read(&mut client)) + .await + .expect("timed out waiting for TLS session close"); + assert!(read_res.is_err(), "server should close TLS session when client SAN IP mismatches"); + dump.push_value( + 0, + json!({ + "event": "tls_session_closed_after_san_ip_mismatch", + "result": "header_read_failed_as_expected" + }), + ); + dump.print_pretty("tls_client_with_mismatched_san_ip_is_rejected"); + + shutdown_io(&mut client, shutdown_tx, 
server_handle).await; +} + +#[tokio::test] +async fn invalid_timing_prevents_end_of_data_response() { + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 8000, 7200)) + .build(); + + let server_cache = shared_cache(cache); + let (addr, shutdown_tx, server_handle) = start_session_server_returning_result(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let read_res = timeout(Duration::from_secs(1), Header::read(&mut client)) + .await + .expect("timed out waiting for server close"); + assert!(read_res.is_err(), "server should close instead of sending invalid EndOfData"); + + let _ = shutdown_tx.send(true); + let join = timeout(Duration::from_secs(1), server_handle) + .await + .expect("server task did not exit within timeout") + .unwrap(); + let err = join.expect_err("session should fail on invalid timing"); + assert!(err.to_string().contains("retry interval")); +} + +#[tokio::test] +async fn invalid_aspa_prevents_snapshot_response() { + let snapshot = Snapshot::from_payloads(vec![Payload::Aspa(Aspa::new( + Asn::from(64496u32), + vec![], + ))]); + let cache = RtrCacheBuilder::new() + .session_ids(SessionIds::from_array([42, 42, 42])) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .build(); + + let server_cache = shared_cache(cache); + let (addr, shutdown_tx, server_handle) = start_session_server_returning_result(server_cache).await; + + let mut client = TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(2).write(&mut client).await.unwrap(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + assert_eq!(response.version(), 2); + assert_eq!(response.session_id(), 42); + + let read_res = 
timeout(Duration::from_secs(1), Header::read(&mut client)) + .await + .expect("timed out waiting for server close"); + assert!(read_res.is_err(), "server should close instead of sending invalid ASPA"); + + let _ = shutdown_tx.send(true); + let join = timeout(Duration::from_secs(1), server_handle) + .await + .expect("server task did not exit within timeout") + .unwrap(); + let err = join.expect_err("session should fail on invalid ASPA"); + assert!(err.to_string().contains("ASPA announcement")); +} diff --git a/tests/test_store_db.rs b/tests/test_store_db.rs index a1550d2..42e6274 100644 --- a/tests/test_store_db.rs +++ b/tests/test_store_db.rs @@ -6,9 +6,9 @@ use common::test_helper::{ indent_block, payloads_to_string, test_report, v4_origin, v6_origin, }; -use rpki::rtr::cache::{Delta, Snapshot}; +use rpki::rtr::cache::{CacheAvailability, Delta, SessionIds, Snapshot}; use rpki::rtr::payload::Payload; -use rpki::rtr::store_db::RtrStore; +use rpki::rtr::store::RtrStore; fn snapshot_to_string(snapshot: &Snapshot) -> String { let payloads = snapshot.payloads_for_rtr(); @@ -69,37 +69,52 @@ fn store_db_save_and_get_snapshot() { fn store_db_set_and_get_meta_fields() { let dir = tempfile::tempdir().unwrap(); let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = SessionIds::from_array([40, 41, 42]); - store.set_session_id(42).unwrap(); + store.set_session_ids(&session_ids).unwrap(); store.set_serial(100).unwrap(); store.set_delta_window(101, 110).unwrap(); - let session_id = store.get_session_id().unwrap(); + let loaded_session_ids = store.get_session_ids().unwrap(); let serial = store.get_serial().unwrap(); let window = store.get_delta_window().unwrap(); let input = format!( - "db_path: {}\nset_session_id=42\nset_serial=100\nset_delta_window=(101, 110)\n", + "db_path: {}\nset_session_ids={:?}\nset_serial=100\nset_delta_window=(101, 110)\n", dir.path().display(), + session_ids, ); let output = format!( - "get_session_id: {:?}\nget_serial: 
{:?}\nget_delta_window: {:?}\n", - session_id, serial, window, + "get_session_ids: {:?}\nget_serial: {:?}\nget_delta_window: {:?}\n", + loaded_session_ids, serial, window, ); test_report( "store_db_set_and_get_meta_fields", - "验证 session_id / serial / delta_window 能正确写入并读回。", + "验证 session_ids / serial / delta_window 能正确写入并读回。", &input, &output, ); - assert_eq!(session_id, Some(42)); + assert_eq!(loaded_session_ids, Some(session_ids)); assert_eq!(serial, Some(100)); assert_eq!(window, Some((101, 110))); } +#[test] +fn store_db_clear_delta_window_removes_both_bounds() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + store.set_delta_window(101, 110).unwrap(); + assert_eq!(store.get_delta_window().unwrap(), Some((101, 110))); + + store.clear_delta_window().unwrap(); + + assert_eq!(store.get_delta_window().unwrap(), None); +} + #[test] fn store_db_save_and_get_delta() { let dir = tempfile::tempdir().unwrap(); @@ -194,43 +209,220 @@ fn store_db_load_deltas_since_returns_only_newer_deltas_in_order() { fn store_db_save_snapshot_and_meta_writes_all_fields() { let dir = tempfile::tempdir().unwrap(); let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = SessionIds::from_array([40, 41, 42]); let snapshot = Snapshot::from_payloads(vec![ Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), ]); - store.save_snapshot_and_meta(&snapshot, 42, 100).unwrap(); + store + .save_snapshot_and_meta(&snapshot, &session_ids, 100) + .unwrap(); let loaded_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist"); - let loaded_session = store.get_session_id().unwrap(); + let loaded_session_ids = store.get_session_ids().unwrap(); let loaded_serial = store.get_serial().unwrap(); let input = format!( - "db_path: {}\nsnapshot:\n{}session_id=42\nserial=100\n", + "db_path: {}\nsnapshot:\n{}session_ids={:?}\nserial=100\n", dir.path().display(), 
indent_block(&snapshot_to_string(&snapshot), 2), + session_ids, ); let output = format!( - "loaded_snapshot:\n{}loaded_session_id: {:?}\nloaded_serial: {:?}\n", + "loaded_snapshot:\n{}loaded_session_ids: {:?}\nloaded_serial: {:?}\n", indent_block(&snapshot_to_string(&loaded_snapshot), 2), - loaded_session, + loaded_session_ids, loaded_serial, ); test_report( "store_db_save_snapshot_and_meta_writes_all_fields", - "验证 save_snapshot_and_meta() 会同时写入 snapshot、session_id 和 serial。", + "验证 save_snapshot_and_meta() 会同时写入 snapshot、session_ids 和 serial。", &input, &output, ); assert!(snapshot.same_content(&loaded_snapshot)); - assert_eq!(loaded_session, Some(42)); + assert_eq!(loaded_session_ids, Some(session_ids)); assert_eq!(loaded_serial, Some(100)); } +#[test] +fn store_db_save_cache_state_writes_delta_snapshot_meta_and_window_together() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = SessionIds::from_array([40, 41, 42]); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), + ]); + let delta = Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + ); + + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 101, + Some(&delta), + Some((101, 101)), + false, + ) + .unwrap(); + + let loaded_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist"); + let loaded_session_ids = store.get_session_ids().unwrap(); + let loaded_serial = store.get_serial().unwrap(); + let loaded_availability = store.get_availability().unwrap(); + let loaded_delta = store.get_delta(101).unwrap().expect("delta should exist"); + let loaded_window = store.get_delta_window().unwrap(); + + assert!(snapshot.same_content(&loaded_snapshot)); + 
assert_eq!(loaded_session_ids, Some(session_ids)); + assert_eq!(loaded_serial, Some(101)); + assert_eq!(loaded_availability, Some(CacheAvailability::Ready)); + assert_eq!(loaded_delta.serial(), 101); + assert_eq!(loaded_window, Some((101, 101))); +} + +#[test] +fn store_db_save_cache_state_prunes_deltas_older_than_window_min() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = SessionIds::from_array([40, 41, 42]); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + ]); + + let d101 = Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + ); + let d102 = Delta::new( + 102, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + ); + let d103 = Delta::new( + 103, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + ); + + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 101, + Some(&d101), + Some((101, 101)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 102, + Some(&d102), + Some((101, 102)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 103, + Some(&d103), + Some((103, 103)), + false, + ) + .unwrap(); + + assert!(store.get_delta(101).unwrap().is_none()); + assert!(store.get_delta(102).unwrap().is_none()); + assert_eq!(store.get_delta(103).unwrap().map(|d| d.serial()), Some(103)); + assert_eq!(store.get_delta_window().unwrap(), Some((103, 103))); +} + +#[test] +fn store_db_load_delta_window_restores_wraparound_window_in_serial_order() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + let session_ids = SessionIds::from_array([40, 41, 42]); + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 
0, 2, 0, 24, 24, 64496)), + ]); + + let d_max = Delta::new( + u32::MAX, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + ); + let d_zero = Delta::new( + 0, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + ); + let d_one = Delta::new( + 1, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + ); + + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + u32::MAX, + Some(&d_max), + Some((u32::MAX, u32::MAX)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 0, + Some(&d_zero), + Some((u32::MAX, 0)), + false, + ) + .unwrap(); + store + .save_cache_state( + CacheAvailability::Ready, + &snapshot, + &session_ids, + 1, + Some(&d_one), + Some((u32::MAX, 1)), + false, + ) + .unwrap(); + + let loaded = store.load_delta_window(u32::MAX, 1).unwrap(); + + assert_eq!(loaded.iter().map(Delta::serial).collect::<Vec<_>>(), vec![u32::MAX, 0, 1]); +} + +#[test] +fn store_db_load_snapshot_and_serial_returns_consistent_pair() { + let dir = tempfile::tempdir().unwrap(); @@ -349,4 +541,4 @@ fn store_db_load_snapshot_and_serial_errors_on_inconsistent_state() { + ); + assert!(result.is_err()); -} \ No newline at end of file +}