Add rtr-client; improve TLS connection handling

This commit is contained in:
xiuting.xu 2026-03-25 10:08:40 +08:00
parent 9cbea4e2d0
commit 03c0ab0ec7
45 changed files with 7521 additions and 1835 deletions


@ -29,4 +29,5 @@ tokio-rustls = "0.26"
rustls = "0.23"
rustls-pemfile = "2"
rustls-pki-types = "1.14.0"
socket2 = "0.5"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }


@ -1,11 +1,73 @@
# RPKI RTR Server
Default runtime target: Ubuntu/Linux. Windows is only used during development.
## Tests
```bash
cargo test
```
To show test output:
```bash
cargo test -- --nocapture
```
## RTR Server
The RTR server binary reads its runtime configuration from environment variables.
If an environment variable is not set, the built-in default from `src/main.rs`
is used.
### Environment Variables
| Variable | Description | Example |
| --- | --- | --- |
| `RPKI_RTR_ENABLE_TLS` | Enable TLS listener in addition to TCP. Accepts `true/false`, `1/0`, `yes/no`, `on/off`. | `true` |
| `RPKI_RTR_TCP_ADDR` | TCP bind address. | `0.0.0.0:3323` |
| `RPKI_RTR_TLS_ADDR` | TLS bind address. | `0.0.0.0:3324` |
| `RPKI_RTR_DB_PATH` | RTR RocksDB path. | `./rtr-db` |
| `RPKI_RTR_VRP_FILE` | Input VRP file path. | `./data/vrps.txt` |
| `RPKI_RTR_TLS_CERT_PATH` | TLS server certificate path. | `./certs/server.crt` |
| `RPKI_RTR_TLS_KEY_PATH` | TLS server private key path. | `./certs/server.key` |
| `RPKI_RTR_TLS_CLIENT_CA_PATH` | Client CA certificate path used to verify router certificates. | `./certs/client-ca.crt` |
| `RPKI_RTR_MAX_DELTA` | Maximum retained delta count. | `100` |
| `RPKI_RTR_REFRESH_INTERVAL_SECS` | VRP reload interval in seconds. | `300` |
| `RPKI_RTR_MAX_CONNECTIONS` | Maximum concurrent RTR connections. | `512` |
| `RPKI_RTR_NOTIFY_QUEUE_SIZE` | Broadcast queue size for serial notify events. | `1024` |
| `RPKI_RTR_TCP_KEEPALIVE_SECS` | TCP keepalive time in seconds. Set `0` to disable. | `60` |
| `RPKI_RTR_WARN_INSECURE_TCP` | Emit a warning when plain TCP is enabled. Accepts boolean values. | `true` |
| `RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN` | Strict mode: reject TLS server certificates that do not contain a `subjectAltName dNSName`. Accepts boolean values. | `false` |
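As a reference, the boolean variables could be interpreted along these lines. This is a minimal sketch of the accepted spellings listed above; `parse_bool_env` is an illustrative name, not necessarily the server's actual helper:

```rust
// Hedged sketch: interpret a boolean env var such as RPKI_RTR_ENABLE_TLS.
// Accepts true/false, 1/0, yes/no, on/off (case-insensitive); anything
// else yields None so the caller can fall back to the built-in default.
fn parse_bool_env(value: &str) -> Option<bool> {
    match value.trim().to_ascii_lowercase().as_str() {
        "true" | "1" | "yes" | "on" => Some(true),
        "false" | "0" | "no" | "off" => Some(false),
        _ => None,
    }
}
```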
### Notes
- Plain TCP should only be used on a trusted and controlled network.
- TLS mode requires client certificate authentication.
- In strict TLS server certificate mode, a server certificate without
`subjectAltName dNSName` will be rejected during startup.
- `RPKI_RTR_TCP_KEEPALIVE_SECS=0` disables TCP keepalive. Any non-zero value
enables keepalive for the lifetime of each accepted socket.
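The keepalive rule above can be sketched as a mapping from the env value to an optional keepalive duration. This is an assumption-level illustration; the actual socket configuration (via the `socket2` dependency added in this commit) is omitted:

```rust
use std::time::Duration;

// Hedged sketch: map RPKI_RTR_TCP_KEEPALIVE_SECS to a keepalive setting.
// 0 means "disabled"; any non-zero value enables keepalive with that period.
fn keepalive_from_secs(secs: u64) -> Option<Duration> {
    if secs == 0 {
        None // keepalive disabled
    } else {
        Some(Duration::from_secs(secs))
    }
}
```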
## Example Startup
### Bash
```sh
export RPKI_RTR_ENABLE_TLS=true
export RPKI_RTR_TCP_ADDR=0.0.0.0:3323
export RPKI_RTR_TLS_ADDR=0.0.0.0:3324
export RPKI_RTR_DB_PATH=./rtr-db
export RPKI_RTR_VRP_FILE=./data/vrps.txt
export RPKI_RTR_TLS_CERT_PATH=./certs/server-dns.crt
export RPKI_RTR_TLS_KEY_PATH=./certs/server-dns.key
export RPKI_RTR_TLS_CLIENT_CA_PATH=./certs/client-ca.crt
export RPKI_RTR_TCP_KEEPALIVE_SECS=60
export RPKI_RTR_WARN_INSECURE_TCP=true
export RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN=true
cargo run
```
A ready-to-edit example script is provided at
[`scripts/start-rtr-server.sh`](scripts/start-rtr-server.sh).

data/vrps.txt Normal file

@ -0,0 +1,4 @@
# prefix,max_len,asn
10.0.0.0/24,25,65001
10.0.1.0/24,24,65022
2001:db8::/32,64,65003
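The file follows the `prefix,max_len,asn` layout from its header comment. A minimal parsing sketch for one line (a hypothetical helper for illustration; the server's actual loader is `load_vrps_from_file`):

```rust
// Hedged sketch: parse one `prefix,max_len,asn` line from the VRP file.
// Comment lines (starting with '#') and blank lines yield None.
fn parse_vrp_line(line: &str) -> Option<(String, u8, u32)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let mut parts = line.split(',');
    let prefix = parts.next()?.trim().to_string();
    let max_len: u8 = parts.next()?.trim().parse().ok()?;
    let asn: u32 = parts.next()?.trim().parse().ok()?;
    Some((prefix, max_len, asn))
}
```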


@ -1,63 +1,235 @@
# rtr_debug_client
`rtr_debug_client` is a lightweight RTR debug client for manual integration testing and for observing protocol behavior.
It is suited to the following scenarios:
- verifying RTR server behavior during development
- sending `Reset Query` and `Serial Query`
- inspecting the PDUs returned by the server
- watching session state and changes to `session_id` and `serial`
- troubleshooting `ErrorReport`, `CacheReset`, `SerialNotify`, `RouterKey`, and `ASPA`
- exercising both plain TCP and TLS RTR transports
It is not a production-grade router client; it is a small tool that makes protocol details easy to debug and observe.
## Current Capabilities
The current version supports:
- plain TCP connections
- TLS connections
- TLS server certificate verification
- optional TLS client certificate authentication
- sending `Reset Query`
- sending `Serial Query`
- keeping the connection open and continuously receiving server PDUs
- formatted display of the following PDUs:
  - `Serial Notify`
  - `Serial Query`
  - `Reset Query`
  - `Cache Response`
  - `IPv4 Prefix`
  - `IPv6 Prefix`
  - `Router Key`
  - `ASPA`
  - `End of Data`
  - `Cache Reset`
  - `Error Report`
- structured display of `ErrorReport`:
  - error code and its semantic name
  - a header summary of the encapsulated PDU
  - the raw hex of the encapsulated PDU
  - whether the arbitrary text is valid UTF-8
  - the arbitrary text content
- automatic polling based on the `EndOfData` timing hints
- pausing automatic polling by default after an `ErrorReport`
- `--keep-after-error` to keep polling automatically after errors
--- ## 构建
## 目录结构 ```sh
cargo build --bin rtr_debug_client
```
建议目录如下: ## 基本用法
```text 基本形式:
src/
└── bin/ ```sh
└── rtr_debug_client/ cargo run --bin rtr_debug_client -- <addr> <version> [reset|serial <session_id> <serial>] [options]
├── main.rs ```
├── protocol.rs
├── io.rs 默认值:
├── pretty.rs - `addr`: `127.0.0.1:3323`
└── README.md - `version`: `1`
- `mode`: `reset`
- `timeout`: `30`
- `poll`: `60`
## TCP Examples
Send a `Reset Query`:
```sh
cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 reset
```
Send a `Serial Query`:
```sh
cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 serial 42 100
```
Keep observing the error path:
```sh
cargo run --bin rtr_debug_client -- 127.0.0.1:3323 1 reset --keep-after-error
```
## TLS Examples
Server certificate verification only:
```sh
cargo run --bin rtr_debug_client -- \
    127.0.0.1:3324 1 reset \
    --tls \
    --ca-cert tests/fixtures/tls/client-ca.crt \
    --server-name localhost
```
Mutual TLS authentication:
```sh
cargo run --bin rtr_debug_client -- \
    127.0.0.1:3324 1 reset \
    --tls \
    --ca-cert tests/fixtures/tls/client-ca.crt \
    --server-name localhost \
    --client-cert tests/fixtures/tls/client-good.crt \
    --client-key tests/fixtures/tls/client-good.key
```
Mutual TLS plus continued auto polling after errors:
```sh
cargo run --bin rtr_debug_client -- \
    127.0.0.1:3324 1 reset \
    --tls \
    --ca-cert tests/fixtures/tls/client-ca.crt \
    --server-name localhost \
    --client-cert tests/fixtures/tls/client-good.crt \
    --client-key tests/fixtures/tls/client-good.key \
    --keep-after-error
```
Notes:
- `--ca-cert` is required when `--tls` is enabled.
- If the target address itself is not suitable as a TLS name, pass `--server-name` explicitly.
- Client authentication requires both `--client-cert` and `--client-key`.
## Command-Line Options
- `--tls`
  Use TLS instead of plain TCP.
- `--ca-cert <path>`
  CA certificate file (PEM) used to verify the server certificate.
- `--server-name <name>`
  Server name used for certificate verification during the TLS handshake.
- `--client-cert <path>`
  Client certificate (PEM) used for mutual TLS.
- `--client-key <path>`
  Client private key (PEM) matching `--client-cert`.
- `--timeout <secs>`
  Read timeout in seconds while waiting for the next PDU.
- `--poll <secs>`
  Default auto-poll interval used before an `EndOfData` timing hint is known.
- `--keep-after-error`
  Do not pause auto polling after an `ErrorReport`.
## Interactive Commands
Once the client is running, the following console commands are available:
- `help`
  Show help.
- `state`
  Print the current client state.
- `reset`
  Send a `Reset Query`.
- `serial`
  Send a `Serial Query` using the current `session_id` and `serial`.
- `serial <sid> <serial>`
  Send a `Serial Query` with explicit parameters.
- `timeout`
  Show the current read timeout.
- `timeout <secs>`
  Update the read timeout.
- `poll`
  Show the current poll interval, its source, and the pause state.
- `poll <secs>`
  Manually override the poll interval.
- `poll pause`
  Pause auto polling.
- `poll resume`
  Resume auto polling.
- `keep-after-error`
  Show whether polling continues after errors.
- `quit`
  Exit the client.
## Auto-Polling Behavior
The client keeps the connection open and periodically issues the next query.
The next poll interval is chosen with the following priority:
1. `retry`, when the most recent `ErrorReport` was `No Data Available` or `Transport Failure`
2. `refresh`, if one has been obtained from `EndOfData`
3. the default poll interval from the startup arguments
Default behavior after receiving an `ErrorReport`:
- auto polling is paused
- the connection stays open so you can keep observing
- you can continue manually with `reset`, `serial`, or `poll resume`
With `--keep-after-error`:
- polling is not paused after an `ErrorReport`
- the client keeps polling at the currently effective interval
Special case:
- when the most recent error was `No Data Available` or `Transport Failure`, resumed auto polling prefers `retry` rather than only looking at `refresh`
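The interval-selection priority described above can be sketched as follows. This is an illustrative standalone function, not the client's actual implementation; it assumes error codes 2 and 10 correspond to `No Data Available` and `Transport Failure`, matching the error-code table printed by the client:

```rust
// Hedged sketch: pick the next poll interval in seconds.
// Priority: retry hint (after codes 2/10) > refresh hint > startup default.
fn next_poll_secs(
    last_error: Option<u16>,
    retry: Option<u32>,
    refresh: Option<u32>,
    default_secs: u64,
) -> u64 {
    let prefer_retry = matches!(last_error, Some(2 | 10));
    if prefer_retry {
        if let Some(r) = retry {
            return r as u64;
        }
    }
    refresh.map(|v| v as u64).unwrap_or(default_secs)
}
```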
## ErrorReport Display
For an `ErrorReport`, the client shows:
- the error code and its semantic name
- the encapsulated PDU length
- a header summary of the encapsulated PDU:
  - PDU type
  - version
  - length
  - field1, interpreted per type as `session_id` or `error_code`
- the raw hex of the encapsulated PDU
- the arbitrary text length
- whether the arbitrary text is valid UTF-8
- the arbitrary text content
This makes it possible to see which request triggered an error without manually decoding the raw hex first.
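The header summary is derived from the fixed 8-byte RTR PDU header (RFC 8210): one version byte, one PDU-type byte, a 16-bit field whose meaning depends on the type (`session_id` or `error_code`), and a 32-bit total length. A minimal decoding sketch (illustrative helper, not the client's `PduHeader` API):

```rust
// Hedged sketch: decode the fixed 8-byte RTR PDU header.
// Layout per RFC 8210: version (1B), pdu_type (1B),
// type-dependent 16-bit field, 32-bit total length, all big-endian.
fn decode_header(raw: [u8; 8]) -> (u8, u8, u16, u32) {
    let version = raw[0];
    let pdu_type = raw[1];
    let field1 = u16::from_be_bytes([raw[2], raw[3]]);
    let length = u32::from_be_bytes([raw[4], raw[5], raw[6], raw[7]]);
    (version, pdu_type, field1, length)
}
```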


@ -1,10 +1,14 @@
use std::env;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration, Instant};
use tokio_rustls::TlsConnector;
mod wire;
mod pretty;
@ -19,15 +23,23 @@ use crate::protocol::{PduHeader, PduType, QueryMode};
const DEFAULT_READ_TIMEOUT_SECS: u64 = 30;
const DEFAULT_POLL_INTERVAL_SECS: u64 = 60;
trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
type DynStream = Box<dyn AsyncStream>;
type ClientWriter = WriteHalf<DynStream>;
#[tokio::main]
async fn main() -> io::Result<()> {
let config = Config::from_args()?;
println!("== RTR debug client ==");
println!("target : {}", config.addr);
println!("transport: {}", config.transport.describe());
println!("version : {}", config.version);
println!("timeout : {}s", config.read_timeout_secs);
println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs);
println!("keep-after-error: {}", config.keep_after_error);
match &config.mode {
QueryMode::Reset => {
println!("mode : reset");
@ -41,14 +53,15 @@ async fn main() -> io::Result<()> {
println!();
print_help();
let stream = connect_stream(&config).await?;
println!("connected to {}", config.addr);
let (mut reader, mut writer) = tokio_io::split(stream);
let mut state = ClientState::new(
config.version,
config.read_timeout_secs,
config.default_poll_secs,
config.keep_after_error,
);
match config.mode {
@ -71,7 +84,7 @@ async fn main() -> io::Result<()> {
let mut stdin_lines = BufReader::new(stdin).lines();
loop {
let poll_sleep = tokio::time::sleep_until(state.poll_deadline());
tokio::pin!(poll_sleep);
tokio::select! {
@ -131,7 +144,7 @@ async fn main() -> io::Result<()> {
}
async fn handle_incoming_pdu(
writer: &mut ClientWriter,
state: &mut ClientState,
header: &PduHeader,
body: &[u8],
@ -139,9 +152,10 @@ async fn handle_incoming_pdu(
match header.pdu_type() {
PduType::CacheResponse => {
state.current_session_id = Some(header.session_id());
state.last_error_code = None;
}
PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa => {
if state.current_session_id.is_none() {
state.current_session_id = Some(header.session_id());
}
@ -161,6 +175,7 @@ async fn handle_incoming_pdu(
state.refresh = eod.refresh;
state.retry = eod.retry;
state.expire = eod.expire;
state.last_error_code = None;
println!(
"updated client state: session_id={}, serial={}",
@ -198,6 +213,11 @@ async fn handle_incoming_pdu(
let notify_serial = parse_serial_notify_serial(body);
println!();
println!(
"[notify] received Serial Notify: session_id={}, notify_serial={:?}",
notify_session_id,
notify_serial
);
match (state.session_id, state.serial, notify_serial) {
(Some(current_session_id), Some(current_serial), Some(_new_serial))
@ -233,6 +253,7 @@ async fn handle_incoming_pdu(
println!("received Cache Reset, send Reset Query");
state.current_session_id = None;
state.serial = None;
state.last_error_code = None;
send_reset_query(writer, state.version).await?;
state.schedule_next_poll();
println!();
@ -240,9 +261,20 @@ async fn handle_incoming_pdu(
PduType::ErrorReport => {
println!();
println!("received Error Report, pause auto polling for debugging.");
state.last_error_code = Some(header.error_code());
if let Some(retry) = state.retry {
println!("server retry hint currently stored: {}s", retry);
if state.should_prefer_retry_poll() {
println!("when resumed, auto polling will use retry instead of refresh.");
}
}
if state.keep_after_error {
println!("keep-after-error is enabled, auto polling will continue.");
state.schedule_next_poll();
} else {
println!("use `reset`, `serial`, or `poll resume` to continue manually.");
state.pause_auto_poll();
}
println!();
}
@ -256,7 +288,7 @@ async fn handle_incoming_pdu(
}
async fn handle_poll_tick(
writer: &mut ClientWriter,
state: &mut ClientState,
) -> io::Result<()> {
println!();
@ -285,7 +317,7 @@ async fn handle_poll_tick(
async fn handle_console_command(
line: &str,
writer: &mut ClientWriter,
state: &mut ClientState,
) -> io::Result<bool> {
let line = line.trim();
@ -382,8 +414,24 @@ async fn handle_console_command(
"current effective poll interval: {}s",
state.effective_poll_secs()
);
println!("poll interval source : {}", state.poll_interval_source());
println!("stored refresh hint : {:?}", state.refresh);
println!("default poll interval : {}s", state.default_poll_secs);
println!("last_error_code : {:?}", state.last_error_code);
println!("auto polling paused : {}", state.poll_paused);
}
["poll", "pause"] => {
state.pause_auto_poll();
println!("auto polling paused");
}
["poll", "resume"] => {
state.resume_auto_poll();
println!(
"auto polling resumed, next poll scheduled after {}s",
state.effective_poll_secs()
);
} }
["poll", secs] => {
@ -428,6 +476,9 @@ fn print_help() {
println!(" timeout <secs> update read timeout seconds");
println!(" poll show current poll interval");
println!(" poll <secs> override poll interval seconds");
println!(" poll pause pause auto polling");
println!(" poll resume resume auto polling");
println!(" keep-after-error show current keep-after-error setting");
println!(" quit exit client");
println!();
}
@ -444,6 +495,10 @@ fn print_state(state: &ClientState) {
println!(" read_timeout_secs : {}", state.read_timeout_secs);
println!(" default_poll_secs : {}", state.default_poll_secs);
println!(" effective_poll_secs: {}", state.effective_poll_secs());
println!(" poll_source : {}", state.poll_interval_source());
println!(" last_error_code : {:?}", state.last_error_code);
println!(" keep_after_error : {}", state.keep_after_error);
println!(" poll_paused : {}", state.poll_paused);
println!();
}
@ -457,14 +512,22 @@ struct ClientState {
refresh: Option<u32>,
retry: Option<u32>,
expire: Option<u32>,
last_error_code: Option<u16>,
keep_after_error: bool,
read_timeout_secs: u64,
default_poll_secs: u64,
next_poll_deadline: Instant,
poll_paused: bool,
}
impl ClientState {
fn new(
version: u8,
read_timeout_secs: u64,
default_poll_secs: u64,
keep_after_error: bool,
) -> Self {
Self {
version,
session_id: None,
@ -473,20 +536,60 @@ impl ClientState {
refresh: None,
retry: None,
expire: None,
last_error_code: None,
keep_after_error,
read_timeout_secs,
default_poll_secs,
next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs),
poll_paused: false,
}
}
fn effective_poll_secs(&self) -> u64 {
if self.should_prefer_retry_poll() {
self.retry
.map(|v| v as u64)
.unwrap_or_else(|| self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs))
} else {
self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs)
}
}
fn schedule_next_poll(&mut self) {
self.next_poll_deadline =
Instant::now() + Duration::from_secs(self.effective_poll_secs());
}
fn pause_auto_poll(&mut self) {
self.poll_paused = true;
}
fn resume_auto_poll(&mut self) {
self.poll_paused = false;
self.schedule_next_poll();
}
fn poll_deadline(&self) -> Instant {
if self.poll_paused {
Instant::now() + Duration::from_secs(365 * 24 * 60 * 60)
} else {
self.next_poll_deadline
}
}
fn should_prefer_retry_poll(&self) -> bool {
matches!(self.last_error_code, Some(2 | 10))
}
fn poll_interval_source(&self) -> &'static str {
if self.should_prefer_retry_poll() && self.retry.is_some() {
"retry"
} else if self.refresh.is_some() {
"refresh"
} else {
"default"
}
}
}
#[derive(Debug)]
@ -496,17 +599,80 @@ struct Config {
mode: QueryMode,
read_timeout_secs: u64,
default_poll_secs: u64,
transport: TransportConfig,
keep_after_error: bool,
}
impl Config {
fn from_args() -> io::Result<Self> {
let mut args = env::args().skip(1);
let mut positional = Vec::new();
let mut transport = TransportConfig::Tcp;
let mut read_timeout_secs = DEFAULT_READ_TIMEOUT_SECS;
let mut default_poll_secs = DEFAULT_POLL_INTERVAL_SECS;
let mut keep_after_error = false;
while let Some(arg) = args.next() {
match arg.as_str() {
"--tls" => {
transport = TransportConfig::Tls(TlsConfig::default());
}
"--ca-cert" => {
let path = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--ca-cert requires a path")
})?;
ensure_tls_config(&mut transport)?.ca_cert = Some(PathBuf::from(path));
}
"--client-cert" => {
let path = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--client-cert requires a path")
})?;
ensure_tls_config(&mut transport)?.client_cert = Some(PathBuf::from(path));
}
"--client-key" => {
let path = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--client-key requires a path")
})?;
ensure_tls_config(&mut transport)?.client_key = Some(PathBuf::from(path));
}
"--server-name" => {
let name = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--server-name requires a value")
})?;
ensure_tls_config(&mut transport)?.server_name = Some(name);
}
"--timeout" => {
let secs = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--timeout requires seconds")
})?;
read_timeout_secs = parse_u64_arg(&secs, "--timeout")?;
}
"--poll" => {
let secs = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--poll requires seconds")
})?;
default_poll_secs = parse_u64_arg(&secs, "--poll")?;
}
"--keep-after-error" => {
keep_after_error = true;
}
_ if arg.starts_with("--") => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown option '{}'", arg),
));
}
_ => positional.push(arg),
}
}
let mut positional = positional.into_iter();
let addr = positional
.next()
.unwrap_or_else(|| "127.0.0.1:3323".to_string());
let version = positional
.next()
.map(|s| {
s.parse::<u8>().map_err(|e| {
@ -526,10 +692,10 @@ impl Config {
));
}
let mode = match positional.next().as_deref() {
None | Some("reset") => QueryMode::Reset,
Some("serial") => {
let session_id = positional
.next()
.ok_or_else(|| {
io::Error::new(
@ -545,7 +711,7 @@ impl Config {
)
})?;
let serial = positional
.next()
.ok_or_else(|| {
io::Error::new(
@ -571,12 +737,212 @@ impl Config {
}
};
let transport = finalize_transport(transport, &addr)?;
Ok(Self {
addr,
version,
mode,
read_timeout_secs,
default_poll_secs,
transport,
keep_after_error,
})
}
}
#[derive(Debug, Clone)]
enum TransportConfig {
Tcp,
Tls(TlsConfig),
}
impl TransportConfig {
fn describe(&self) -> String {
match self {
Self::Tcp => "tcp".to_string(),
Self::Tls(cfg) => format!(
"tls (server_name={}, ca_cert={}, client_cert={})",
cfg.server_name.as_deref().unwrap_or("<unset>"),
cfg.ca_cert
.as_ref()
.map(|path| path.display().to_string())
.unwrap_or_else(|| "<unset>".to_string()),
cfg.client_cert
.as_ref()
.map(|path| path.display().to_string())
.unwrap_or_else(|| "<none>".to_string())
),
}
}
}
#[derive(Debug, Clone, Default)]
struct TlsConfig {
server_name: Option<String>,
ca_cert: Option<PathBuf>,
client_cert: Option<PathBuf>,
client_key: Option<PathBuf>,
}
fn ensure_tls_config(transport: &mut TransportConfig) -> io::Result<&mut TlsConfig> {
if matches!(transport, TransportConfig::Tcp) {
*transport = TransportConfig::Tls(TlsConfig::default());
}
match transport {
TransportConfig::Tls(cfg) => Ok(cfg),
TransportConfig::Tcp => unreachable!(),
}
}
fn finalize_transport(transport: TransportConfig, addr: &str) -> io::Result<TransportConfig> {
match transport {
TransportConfig::Tcp => Ok(TransportConfig::Tcp),
TransportConfig::Tls(mut cfg) => {
let ca_cert = cfg.ca_cert.take().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"TLS mode requires --ca-cert <path>",
)
})?;
match (&cfg.client_cert, &cfg.client_key) {
(Some(_), Some(_)) | (None, None) => {}
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS client authentication requires both --client-cert and --client-key",
));
}
}
let server_name = cfg
.server_name
.take()
.or_else(|| default_server_name_for_addr(addr))
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"TLS mode requires --server-name or an address with a parsable host",
)
})?;
Ok(TransportConfig::Tls(TlsConfig {
server_name: Some(server_name),
ca_cert: Some(ca_cert),
client_cert: cfg.client_cert,
client_key: cfg.client_key,
}))
}
}
}
async fn connect_stream(config: &Config) -> io::Result<DynStream> {
match &config.transport {
TransportConfig::Tcp => Ok(Box::new(TcpStream::connect(&config.addr).await?)),
TransportConfig::Tls(tls) => connect_tls_stream(&config.addr, tls).await,
}
}
async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result<DynStream> {
let stream = TcpStream::connect(addr).await?;
let connector = build_tls_connector(tls)?;
let server_name_str = tls
.server_name
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS server name"))?;
let server_name = ServerName::try_from(server_name_str.clone()).map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid TLS server name '{}': {}", server_name_str, err),
)
})?;
let tls_stream = connector.connect(server_name, stream).await.map_err(|err| {
io::Error::new(io::ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", err))
})?;
Ok(Box::new(tls_stream))
}
fn build_tls_connector(tls: &TlsConfig) -> io::Result<TlsConnector> {
let ca_cert_path = tls
.ca_cert
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS CA cert"))?;
let ca_certs = load_certs(ca_cert_path)?;
let mut roots = RootCertStore::empty();
let (added, _ignored) = roots.add_parsable_certificates(ca_certs);
if added == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("no valid CA certificates found in {}", ca_cert_path.display()),
));
}
let builder = RustlsClientConfig::builder().with_root_certificates(roots);
let client_config = match (&tls.client_cert, &tls.client_key) {
(Some(cert_path), Some(key_path)) => {
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
builder.with_client_auth_cert(certs, key).map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid TLS client certificate/key: {}", err),
)
})?
}
(None, None) => builder.with_no_client_auth(),
_ => unreachable!(),
};
Ok(TlsConnector::from(Arc::new(client_config)))
}
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
let certs = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
if certs.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("no certificates found in {}", path.display()),
));
}
Ok(certs)
}
fn load_private_key(path: &Path) -> io::Result<PrivateKeyDer<'static>> {
let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
rustls_pemfile::private_key(&mut reader)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("no private key found in {}", path.display()),
)
})
}
fn default_server_name_for_addr(addr: &str) -> Option<String> {
if let Some(rest) = addr.strip_prefix('[') {
return rest.split(']').next().map(str::to_string);
}
addr.rsplit_once(':').map(|(host, _port)| host.to_string())
}
fn parse_u64_arg(value: &str, name: &str) -> io::Result<u64> {
let parsed = value.parse::<u64>().map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid value for {}: {}", name, err),
)
})?;
if parsed == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("{} must be > 0", name),
));
}
Ok(parsed)
}


@ -25,6 +25,9 @@ pub fn print_pdu(header: &PduHeader, body: &[u8]) {
PduType::Ipv6Prefix => {
print_ipv6_prefix(header, body);
}
PduType::RouterKey => {
print_router_key(header, body);
}
PduType::EndOfData => {
print_end_of_data(header, body);
}
@ -128,7 +131,11 @@ fn print_end_of_data(header: &PduHeader, body: &[u8]) {
}
fn print_error_report(header: &PduHeader, body: &[u8]) {
println!(
"error_code : {} ({})",
header.error_code(),
error_code_name(header.error_code())
);
if body.len() < 8 {
println!("invalid ErrorReport body length: {}", body.len());
@ -165,8 +172,27 @@ fn print_error_report(header: &PduHeader, body: &[u8]) {
let text = String::from_utf8_lossy(text_bytes);
println!("encap_len : {}", encapsulated_len);
if let Some(encap_header) = parse_encapsulated_header(encapsulated) {
println!("encap_pdu_type : {}", encap_header.pdu_type());
println!("encap_version : {}", encap_header.version);
println!("encap_length : {}", encap_header.length);
match encap_header.pdu_type() {
PduType::ErrorReport => {
println!("encap_field1 : error_code={}", encap_header.error_code());
}
PduType::Unknown(_) => {
println!("encap_field1 : {}", encap_header.field1);
}
_ => {
println!("encap_field1 : session_id={}", encap_header.session_id());
}
}
} else if encapsulated_len > 0 {
println!("encap_header : <truncated or unavailable>");
}
println!("encap_pdu : {}", hex_bytes(encapsulated));
println!("text_len : {}", text_len);
println!("text_utf8 : {}", std::str::from_utf8(text_bytes).is_ok());
println!("text : {}", text);
}
@ -194,7 +220,6 @@ fn print_serial_query(header: &PduHeader, body: &[u8]) {
println!("serial : {}", serial);
}
fn print_router_key(header: &PduHeader, body: &[u8]) {
println!("session_id : {}", header.session_id());
@ -218,6 +243,34 @@ fn print_router_key(header: &PduHeader, body: &[u8]) {
println!("spki : {}", hex_bytes(spki));
}
fn error_code_name(code: u16) -> &'static str {
match code {
0 => "Corrupt Data",
1 => "Internal Error",
2 => "No Data Available",
3 => "Invalid Request",
4 => "Unsupported Protocol Version",
5 => "Unsupported PDU Type",
6 => "Withdrawal of Unknown Record",
7 => "Duplicate Announcement Received",
8 => "Unexpected Protocol Version",
9 => "ASPA Provider List Error",
10 => "Transport Failure",
11 => "Ordering Error",
_ => "Unknown Error Code",
}
}
fn parse_encapsulated_header(encapsulated: &[u8]) -> Option<PduHeader> {
if encapsulated.len() < 8 {
return None;
}
let mut header = [0u8; 8];
header.copy_from_slice(&encapsulated[..8]);
Some(PduHeader::from_bytes(header))
}
fn print_aspa(header: &PduHeader, body: &[u8]) {
println!("session_id : {}", header.session_id());

View File

@ -25,6 +25,7 @@ pub enum PduType {
CacheResponse,
Ipv4Prefix,
Ipv6Prefix,
RouterKey,
EndOfData,
CacheReset,
ErrorReport,
@ -41,6 +42,7 @@ impl PduType {
Self::CacheResponse => 3,
Self::Ipv4Prefix => 4,
Self::Ipv6Prefix => 6,
Self::RouterKey => 9,
Self::EndOfData => 7,
Self::CacheReset => 8,
Self::ErrorReport => 10,
@ -57,6 +59,7 @@ impl PduType {
Self::CacheResponse => "Cache Response",
Self::Ipv4Prefix => "IPv4 Prefix",
Self::Ipv6Prefix => "IPv6 Prefix",
Self::RouterKey => "Router Key",
Self::EndOfData => "End of Data",
Self::CacheReset => "Cache Reset",
Self::ErrorReport => "Error Report",
@ -75,6 +78,7 @@ impl From<u8> for PduType {
3 => Self::CacheResponse,
4 => Self::Ipv4Prefix,
6 => Self::Ipv6Prefix,
9 => Self::RouterKey,
7 => Self::EndOfData,
8 => Self::CacheReset,
10 => Self::ErrorReport,
@ -152,4 +156,4 @@ pub fn hex_bytes(data: &[u8]) -> String {
let _ = write!(out, "{:02x}", b);
}
out
}

View File

@ -1,3 +1,4 @@
use std::env;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
@ -10,7 +11,7 @@ use rpki::rtr::cache::{RtrCache, SharedRtrCache};
use rpki::rtr::loader::load_vrps_from_file;
use rpki::rtr::payload::Timing;
use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService};
use rpki::rtr::store::RtrStore;
#[derive(Debug, Clone)]
struct AppConfig {
@ -22,6 +23,7 @@ struct AppConfig {
vrp_file: String,
tls_cert_path: String,
tls_key_path: String,
tls_client_ca_path: String,
max_delta: u8,
refresh_interval: Duration,
@ -40,23 +42,107 @@ impl Default for AppConfig {
vrp_file: r"C:\Users\xuxiu\git_code\rpki\data\vrps.txt".to_string(),
tls_cert_path: "./certs/server.crt".to_string(),
tls_key_path: "./certs/server.key".to_string(),
tls_client_ca_path: "./certs/client-ca.crt".to_string(),
max_delta: 100,
refresh_interval: Duration::from_secs(10),
service_config: RtrServiceConfig {
max_connections: 512,
notify_queue_size: 1024,
tcp_keepalive: Some(Duration::from_secs(60)),
warn_insecure_tcp: true,
require_tls_server_dns_name_san: false,
},
}
}
}
impl AppConfig {
fn from_env() -> Result<Self> {
let mut config = Self::default();
if let Some(value) = env_var("RPKI_RTR_ENABLE_TLS")? {
config.enable_tls = parse_bool(&value, "RPKI_RTR_ENABLE_TLS")?;
}
if let Some(value) = env_var("RPKI_RTR_TCP_ADDR")? {
config.tcp_addr = value
.parse()
.map_err(|err| anyhow!("invalid RPKI_RTR_TCP_ADDR '{}': {}", value, err))?;
}
if let Some(value) = env_var("RPKI_RTR_TLS_ADDR")? {
config.tls_addr = value
.parse()
.map_err(|err| anyhow!("invalid RPKI_RTR_TLS_ADDR '{}': {}", value, err))?;
}
if let Some(value) = env_var("RPKI_RTR_DB_PATH")? {
config.db_path = value;
}
if let Some(value) = env_var("RPKI_RTR_VRP_FILE")? {
config.vrp_file = value;
}
if let Some(value) = env_var("RPKI_RTR_TLS_CERT_PATH")? {
config.tls_cert_path = value;
}
if let Some(value) = env_var("RPKI_RTR_TLS_KEY_PATH")? {
config.tls_key_path = value;
}
if let Some(value) = env_var("RPKI_RTR_TLS_CLIENT_CA_PATH")? {
config.tls_client_ca_path = value;
}
if let Some(value) = env_var("RPKI_RTR_MAX_DELTA")? {
config.max_delta = value
.parse()
.map_err(|err| anyhow!("invalid RPKI_RTR_MAX_DELTA '{}': {}", value, err))?;
}
if let Some(value) = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")? {
let secs: u64 = value.parse().map_err(|err| {
anyhow!(
"invalid RPKI_RTR_REFRESH_INTERVAL_SECS '{}': {}",
value,
err
)
})?;
config.refresh_interval = Duration::from_secs(secs);
}
if let Some(value) = env_var("RPKI_RTR_MAX_CONNECTIONS")? {
config.service_config.max_connections = value.parse().map_err(|err| {
anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err)
})?;
}
if let Some(value) = env_var("RPKI_RTR_NOTIFY_QUEUE_SIZE")? {
config.service_config.notify_queue_size = value.parse().map_err(|err| {
anyhow!("invalid RPKI_RTR_NOTIFY_QUEUE_SIZE '{}': {}", value, err)
})?;
}
if let Some(value) = env_var("RPKI_RTR_TCP_KEEPALIVE_SECS")? {
let secs: u64 = value.parse().map_err(|err| {
anyhow!("invalid RPKI_RTR_TCP_KEEPALIVE_SECS '{}': {}", value, err)
})?;
config.service_config.tcp_keepalive = if secs == 0 {
None
} else {
Some(Duration::from_secs(secs))
};
}
if let Some(value) = env_var("RPKI_RTR_WARN_INSECURE_TCP")? {
config.service_config.warn_insecure_tcp =
parse_bool(&value, "RPKI_RTR_WARN_INSECURE_TCP")?;
}
if let Some(value) = env_var("RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN")? {
config.service_config.require_tls_server_dns_name_san =
parse_bool(&value, "RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN")?;
}
Ok(config)
}
}
#[tokio::main]
async fn main() -> Result<()> {
init_tracing();
let config = AppConfig::from_env()?;
log_startup_config(&config);
let store = open_store(&config)?;
@ -101,8 +187,8 @@ fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result<SharedRtrCa
.map_err(|_| anyhow!("cache read lock poisoned during startup"))?;
info!(
"cache initialized: session_ids={:?}, serial={}",
cache.session_ids(),
cache.serial()
);
}
@ -118,6 +204,7 @@ fn start_servers(config: &AppConfig, service: &RtrService) -> RunningRtrService
config.tls_addr,
&config.tls_cert_path,
&config.tls_key_path,
&config.tls_client_ca_path,
)
} else {
info!("starting TCP RTR server");
@ -142,6 +229,7 @@ fn spawn_refresh_task(
match load_vrps_from_file(&vrp_file) {
Ok(payloads) => {
let payload_count = payloads.len();
let updated = {
let mut cache = match shared_cache.write() {
Ok(guard) => guard,
@ -154,7 +242,27 @@
let old_serial = cache.serial();
match cache.update(payloads, &store) {
Ok(()) => {
let new_serial = cache.serial();
if new_serial != old_serial {
info!(
"RTR cache refresh applied: vrp_file={}, payload_count={}, old_serial={}, new_serial={}",
vrp_file,
payload_count,
old_serial,
new_serial
);
true
} else {
info!(
"RTR cache refresh found no change: vrp_file={}, payload_count={}, serial={}",
vrp_file,
payload_count,
old_serial
);
false
}
}
Err(err) => {
warn!("RTR cache update failed: {:?}", err);
false
@ -191,6 +299,7 @@ fn log_startup_config(config: &AppConfig) {
info!("tls_addr={}", config.tls_addr);
info!("tls_cert_path={}", config.tls_cert_path);
info!("tls_key_path={}", config.tls_key_path);
info!("tls_client_ca_path={}", config.tls_client_ca_path);
}
info!("vrp_file={}", config.vrp_file);
@ -207,6 +316,22 @@ fn log_startup_config(config: &AppConfig) {
"notify_queue_size={}",
config.service_config.notify_queue_size
);
info!(
"tcp_keepalive_secs={}",
config
.service_config
.tcp_keepalive
.map(|duration| duration.as_secs().to_string())
.unwrap_or_else(|| "disabled".to_string())
);
info!(
"warn_insecure_tcp={}",
config.service_config.warn_insecure_tcp
);
info!(
"require_tls_server_dns_name_san={}",
config.service_config.require_tls_server_dns_name_san
);
}
fn init_tracing() {
@ -215,4 +340,20 @@ fn init_tracing() {
.with_thread_ids(true)
.with_level(true)
.try_init();
}
fn env_var(name: &str) -> Result<Option<String>> {
match env::var(name) {
Ok(value) => Ok(Some(value)),
Err(env::VarError::NotPresent) => Ok(None),
Err(err) => Err(anyhow!("failed to read {}: {}", name, err)),
}
}
fn parse_bool(value: &str, name: &str) -> Result<bool> {
match value.trim().to_ascii_lowercase().as_str() {
"1" | "true" | "yes" | "on" => Ok(true),
"0" | "false" | "no" | "off" => Ok(false),
_ => Err(anyhow!("invalid {} '{}': expected boolean", name, value)),
}
}
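The boolean convention the config accepts can be checked standalone. The sketch below mirrors `parse_bool` above but returns `Option` instead of `anyhow::Result`, so it runs without the binary's dependencies (`parse_bool_opt` is illustrative, not part of the crate):

```rust
// Same accepted spellings as parse_bool, dependency-free.
fn parse_bool_opt(value: &str) -> Option<bool> {
    match value.trim().to_ascii_lowercase().as_str() {
        "1" | "true" | "yes" | "on" => Some(true),
        "0" | "false" | "no" | "off" => Some(false),
        _ => None,
    }
}

fn main() {
    // Whitespace and case are normalized before matching.
    assert_eq!(parse_bool_opt(" On "), Some(true));
    assert_eq!(parse_bool_opt("0"), Some(false));
    // Anything else is rejected rather than silently defaulted.
    assert_eq!(parse_bool_opt("maybe"), None);
}
```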

View File

@ -1,929 +0,0 @@
use std::cmp::Ordering;
use std::collections::{BTreeSet, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::data_model::resources::ip_resources::IPAddress;
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Timing};
use crate::rtr::store_db::RtrStore;
const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(600);
const DEFAULT_EXPIRE_INTERVAL: Duration = Duration::from_secs(7200);
pub type SharedRtrCache = Arc<RwLock<RtrCache>>;
#[derive(Debug, Clone)]
pub struct DualTime {
instant: Instant,
utc: DateTime<Utc>,
}
impl DualTime {
/// Create current time.
pub fn now() -> Self {
Self {
instant: Instant::now(),
utc: Utc::now(),
}
}
/// Get UTC time for logs.
pub fn utc(&self) -> DateTime<Utc> {
self.utc
}
/// Elapsed duration since creation/reset.
pub fn elapsed(&self) -> Duration {
self.instant.elapsed()
}
/// Whether duration is expired.
pub fn is_expired(&self, duration: Duration) -> bool {
self.elapsed() >= duration
}
/// Reset to now.
pub fn reset(&mut self) {
self.instant = Instant::now();
self.utc = Utc::now();
}
}
impl Serialize for DualTime {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.utc.timestamp_millis().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for DualTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let millis = i64::deserialize(deserializer)?;
let naive = NaiveDateTime::from_timestamp_millis(millis)
.ok_or_else(|| serde::de::Error::custom("invalid timestamp"))?;
let utc = DateTime::<Utc>::from_utc(naive, Utc);
Ok(Self {
instant: Instant::now(),
utc,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
origins: BTreeSet<RouteOrigin>,
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
created_at: DualTime,
origins_hash: [u8; 32],
router_keys_hash: [u8; 32],
aspas_hash: [u8; 32],
snapshot_hash: [u8; 32],
}
impl Snapshot {
pub fn new(
origins: BTreeSet<RouteOrigin>,
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
) -> Self {
let mut snapshot = Snapshot {
origins,
router_keys,
aspas,
created_at: DualTime::now(),
origins_hash: [0u8; 32],
router_keys_hash: [0u8; 32],
aspas_hash: [0u8; 32],
snapshot_hash: [0u8; 32],
};
snapshot.recompute_hashes();
snapshot
}
pub fn empty() -> Self {
Self::new(BTreeSet::new(), BTreeSet::new(), BTreeSet::new())
}
pub fn from_payloads(payloads: Vec<Payload>) -> Self {
let mut origins = BTreeSet::new();
let mut router_keys = BTreeSet::new();
let mut aspas = BTreeSet::new();
for p in payloads {
match p {
Payload::RouteOrigin(o) => {
origins.insert(o);
}
Payload::RouterKey(k) => {
router_keys.insert(k);
}
Payload::Aspa(a) => {
aspas.insert(a);
}
}
}
Snapshot::new(origins, router_keys, aspas)
}
pub fn recompute_hashes(&mut self) {
self.origins_hash = self.compute_origins_hash();
self.router_keys_hash = self.compute_router_keys_hash();
self.aspas_hash = self.compute_aspas_hash();
self.snapshot_hash = self.compute_snapshot_hash();
}
fn compute_origins_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.origins.iter())
}
fn compute_router_keys_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.router_keys.iter())
}
fn compute_aspas_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.aspas.iter())
}
fn compute_snapshot_hash(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(b"snapshot:v1");
hasher.update(self.origins_hash);
hasher.update(self.router_keys_hash);
hasher.update(self.aspas_hash);
hasher.finalize().into()
}
fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32]
where
T: Serialize + 'a,
I: IntoIterator<Item = &'a T>,
{
let mut hasher = Sha256::new();
hasher.update(b"set:v1");
for item in iter {
let encoded =
serde_json::to_vec(item).expect("serialize snapshot item for hashing failed");
let len = (encoded.len() as u32).to_be_bytes();
hasher.update(len);
hasher.update(encoded);
}
hasher.finalize().into()
}
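`hash_ordered_iter` length-prefixes each encoded item before feeding it to the hasher; without the prefix, two different sets could concatenate to the same byte stream and collide. A dependency-free illustration of why the prefix matters (std's deterministic `DefaultHasher` stands in for SHA-256 here):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

// Feed items into a hasher two ways: with and without a
// big-endian length prefix per item, as hash_ordered_iter does.
fn digest(items: &[&[u8]], length_prefixed: bool) -> u64 {
    let mut hasher = DefaultHasher::new();
    for item in items {
        if length_prefixed {
            hasher.write(&(item.len() as u32).to_be_bytes());
        }
        hasher.write(item);
    }
    hasher.finish()
}

fn main() {
    let a: [&[u8]; 2] = [b"ab", b"c"];
    let b: [&[u8]; 2] = [b"a", b"bc"];
    // Plain concatenation cannot tell the two sets apart ...
    assert_eq!(digest(&a, false), digest(&b, false));
    // ... the length prefix restores the item boundaries.
    assert_ne!(digest(&a, true), digest(&b, true));
}
```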
pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec<Payload>, Vec<Payload>) {
let mut announced = Vec::new();
let mut withdrawn = Vec::new();
if !self.same_origins(new_snapshot) {
for origin in new_snapshot.origins.difference(&self.origins) {
announced.push(Payload::RouteOrigin(origin.clone()));
}
for origin in self.origins.difference(&new_snapshot.origins) {
withdrawn.push(Payload::RouteOrigin(origin.clone()));
}
}
if !self.same_router_keys(new_snapshot) {
for key in new_snapshot.router_keys.difference(&self.router_keys) {
announced.push(Payload::RouterKey(key.clone()));
}
for key in self.router_keys.difference(&new_snapshot.router_keys) {
withdrawn.push(Payload::RouterKey(key.clone()));
}
}
if !self.same_aspas(new_snapshot) {
for aspa in new_snapshot.aspas.difference(&self.aspas) {
announced.push(Payload::Aspa(aspa.clone()));
}
for aspa in self.aspas.difference(&new_snapshot.aspas) {
withdrawn.push(Payload::Aspa(aspa.clone()));
}
}
(announced, withdrawn)
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
pub fn payloads(&self) -> Vec<Payload> {
let mut v = Vec::with_capacity(
self.origins.len() + self.router_keys.len() + self.aspas.len(),
);
v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin));
v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey));
v.extend(self.aspas.iter().cloned().map(Payload::Aspa));
v
}
/// Payloads sorted for RTR full snapshot sending.
/// Snapshot represents current valid state, so all payloads are treated as announcements.
pub fn payloads_for_rtr(&self) -> Vec<Payload> {
let mut payloads = self.payloads();
sort_payloads_for_rtr(&mut payloads, true);
payloads
}
pub fn origins_hash(&self) -> [u8; 32] {
self.origins_hash
}
pub fn router_keys_hash(&self) -> [u8; 32] {
self.router_keys_hash
}
pub fn aspas_hash(&self) -> [u8; 32] {
self.aspas_hash
}
pub fn snapshot_hash(&self) -> [u8; 32] {
self.snapshot_hash
}
pub fn same_origins(&self, other: &Self) -> bool {
self.origins_hash == other.origins_hash
}
pub fn same_router_keys(&self, other: &Self) -> bool {
self.router_keys_hash == other.router_keys_hash
}
pub fn same_aspas(&self, other: &Self) -> bool {
self.aspas_hash == other.aspas_hash
}
pub fn same_content(&self, other: &Self) -> bool {
self.snapshot_hash == other.snapshot_hash
}
pub fn origins(&self) -> &BTreeSet<RouteOrigin> {
&self.origins
}
pub fn router_keys(&self) -> &BTreeSet<RouterKey> {
&self.router_keys
}
pub fn aspas(&self) -> &BTreeSet<Aspa> {
&self.aspas
}
}
impl Snapshot {
pub fn is_empty(&self) -> bool {
self.origins.is_empty()
&& self.router_keys.is_empty()
&& self.aspas.is_empty()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Delta {
serial: u32,
announced: Vec<Payload>,
withdrawn: Vec<Payload>,
created_at: DualTime,
}
impl Delta {
pub fn new(serial: u32, mut announced: Vec<Payload>, mut withdrawn: Vec<Payload>) -> Self {
sort_payloads_for_rtr(&mut announced, true);
sort_payloads_for_rtr(&mut withdrawn, false);
Delta {
serial,
announced,
withdrawn,
created_at: DualTime::now(),
}
}
pub fn serial(&self) -> u32 {
self.serial
}
pub fn announced(&self) -> &[Payload] {
&self.announced
}
pub fn withdrawn(&self) -> &[Payload] {
&self.withdrawn
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
}
#[derive(Debug)]
pub struct RtrCache {
// Session ID created at cache startup.
session_id: u16,
// Current serial.
pub serial: u32,
// Full snapshot.
pub snapshot: Snapshot,
// Delta window.
deltas: VecDeque<Arc<Delta>>,
// Max number of deltas to keep.
max_delta: u8,
// Refresh interval.
timing: Timing,
// Last update begin time.
last_update_begin: DualTime,
// Last update end time.
last_update_end: DualTime,
// Cache created time.
created_at: DualTime,
}
impl Default for RtrCache {
fn default() -> Self {
let now = DualTime::now();
Self {
session_id: rand::random(),
serial: 0,
snapshot: Snapshot::empty(),
deltas: VecDeque::with_capacity(100),
max_delta: 100,
timing: Timing::default(),
last_update_begin: now.clone(),
last_update_end: now.clone(),
created_at: now,
}
}
}
pub struct RtrCacheBuilder {
session_id: Option<u16>,
max_delta: Option<u8>,
timing: Option<Timing>,
serial: Option<u32>,
snapshot: Option<Snapshot>,
deltas: Option<VecDeque<Arc<Delta>>>,
created_at: Option<DualTime>,
}
impl RtrCacheBuilder {
pub fn new() -> Self {
Self {
session_id: None,
max_delta: None,
timing: None,
serial: None,
snapshot: None,
deltas: None,
created_at: None,
}
}
pub fn session_id(mut self, v: u16) -> Self {
self.session_id = Some(v);
self
}
pub fn max_delta(mut self, v: u8) -> Self {
self.max_delta = Some(v);
self
}
pub fn timing(mut self, v: Timing) -> Self {
self.timing = Some(v);
self
}
pub fn serial(mut self, v: u32) -> Self {
self.serial = Some(v);
self
}
pub fn snapshot(mut self, v: Snapshot) -> Self {
self.snapshot = Some(v);
self
}
pub fn deltas(mut self, v: VecDeque<Arc<Delta>>) -> Self {
self.deltas = Some(v);
self
}
pub fn created_at(mut self, v: DualTime) -> Self {
self.created_at = Some(v);
self
}
pub fn build(self) -> RtrCache {
let now = DualTime::now();
let max_delta = self.max_delta.unwrap_or(100);
let timing = self.timing.unwrap_or_default();
let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty);
let deltas = self
.deltas
.unwrap_or_else(|| VecDeque::with_capacity(max_delta.into()));
let serial = self.serial.unwrap_or(0);
let created_at = self.created_at.unwrap_or_else(|| now.clone());
let session_id = self.session_id.unwrap_or_else(rand::random);
RtrCache {
session_id,
serial,
snapshot,
deltas,
max_delta,
timing,
last_update_begin: now.clone(),
last_update_end: now,
created_at,
}
}
}
impl RtrCache {
/// Initialize cache from DB if possible; otherwise from file loader.
pub fn init(
self,
store: &RtrStore,
max_delta: u8,
timing: Timing,
file_loader: impl Fn() -> anyhow::Result<Vec<Payload>>,
) -> anyhow::Result<Self> {
if let Some(cache) = Self::try_restore_from_store(store, max_delta, timing)? {
tracing::info!(
"RTR cache restored from store: session_id={}, serial={}",
self.session_id,
self.serial
);
return Ok(cache);
}
tracing::warn!("RTR cache store unavailable or invalid, fallback to file loader");
let payloads = file_loader()?;
let snapshot = Snapshot::from_payloads(payloads);
if snapshot.is_empty() {
anyhow::bail!("file loader returned an empty snapshot");
}
tracing::info!(
"RTR cache initialized from file loader: session_id={}, serial={}",
self.session_id,
self.serial
);
let serial = 1;
let session_id: u16 = rand::random();
let snapshot_for_store = snapshot.clone();
let snapshot_for_cache = snapshot.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) =
store.save_snapshot_and_meta(&snapshot_for_store, session_id, serial)
{
tracing::error!("persist failed: {:?}", e);
}
});
Ok(RtrCacheBuilder::new()
.session_id(session_id)
.max_delta(max_delta)
.timing(timing)
.serial(serial)
.snapshot(snapshot_for_cache)
.build())
}
fn try_restore_from_store(
store: &RtrStore,
max_delta: u8,
timing: Timing,
) -> anyhow::Result<Option<Self>> {
let snapshot = store.get_snapshot()?;
let session_id = store.get_session_id()?;
let serial = store.get_serial()?;
let (snapshot, session_id, serial) = match (snapshot, session_id, serial) {
(Some(snapshot), Some(session_id), Some(serial)) => (snapshot, session_id, serial),
_ => {
tracing::warn!("RTR cache store incomplete: snapshot/session_id/serial missing");
return Ok(None);
}
};
if snapshot.is_empty() {
tracing::warn!("RTR cache store snapshot is empty, treat as unusable");
return Ok(None);
}
let mut cache = RtrCacheBuilder::new()
.session_id(session_id)
.max_delta(max_delta)
.timing(timing)
.serial(serial)
.snapshot(snapshot)
.build();
match store.get_delta_window()? {
Some((min_serial, _max_serial)) => {
let deltas = match store.load_deltas_since(min_serial.wrapping_sub(1)) {
Ok(deltas) => deltas,
Err(err) => {
tracing::warn!(
"RTR cache store delta recovery failed, treat store as unusable: {:?}",
err
);
return Ok(None);
}
};
for delta in deltas {
cache.push_delta(Arc::new(delta));
}
}
None => {
tracing::info!("RTR cache store has no delta window, restore snapshot only");
}
}
Ok(Some(cache))
}
fn next_serial(&mut self) -> u32 {
self.serial = self.serial.wrapping_add(1);
self.serial
}
fn push_delta(&mut self, delta: Arc<Delta>) {
if self.deltas.len() >= self.max_delta as usize {
self.deltas.pop_front();
}
self.deltas.push_back(delta);
}
fn replace_snapshot(&mut self, snapshot: Snapshot) {
self.snapshot = snapshot;
}
fn delta_window(&self) -> Option<(u32, u32)> {
let min = self.deltas.front().map(|d| d.serial());
let max = self.deltas.back().map(|d| d.serial());
match (min, max) {
(Some(min), Some(max)) => Some((min, max)),
_ => None,
}
}
fn store_sync(
&mut self,
store: &RtrStore,
snapshot: Snapshot,
serial: u32,
session_id: u16,
delta: Arc<Delta>,
) {
let window = self.delta_window();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = store.save_delta(&delta) {
tracing::error!("persist delta failed: {:?}", e);
}
if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) {
tracing::error!("persist snapshot/meta failed: {:?}", e);
}
if let Some((min_serial, max_serial)) = window {
if let Err(e) = store.set_delta_window(min_serial, max_serial) {
tracing::error!("persist delta window failed: {:?}", e);
}
}
});
}
// Update cache.
pub fn update(
&mut self,
new_payloads: Vec<Payload>,
store: &RtrStore,
) -> anyhow::Result<()> {
self.last_update_begin = DualTime::now();
let new_snapshot = Snapshot::from_payloads(new_payloads);
if self.snapshot.same_content(&new_snapshot) {
self.last_update_end = DualTime::now();
return Ok(());
}
let (announced, withdrawn) = self.snapshot.diff(&new_snapshot);
if announced.is_empty() && withdrawn.is_empty() {
self.last_update_end = DualTime::now();
return Ok(());
}
let new_serial = self.next_serial();
let delta = Arc::new(Delta::new(new_serial, announced, withdrawn));
self.push_delta(delta.clone());
self.replace_snapshot(new_snapshot.clone());
self.last_update_end = DualTime::now();
self.store_sync(store, new_snapshot, new_serial, self.session_id, delta);
Ok(())
}
pub fn session_id(&self) -> u16 {
self.session_id
}
pub fn snapshot(&self) -> Snapshot {
self.snapshot.clone()
}
pub fn serial(&self) -> u32 {
self.serial
}
pub fn timing(&self) -> Timing {
self.timing
}
pub fn retry_interval(&self) -> Duration {
DEFAULT_RETRY_INTERVAL
}
pub fn expire_interval(&self) -> Duration {
DEFAULT_EXPIRE_INTERVAL
}
pub fn current_snapshot(&self) -> (&Snapshot, u32, u16) {
(&self.snapshot, self.serial, self.session_id)
}
pub fn last_update_begin(&self) -> DualTime {
self.last_update_begin.clone()
}
pub fn last_update_end(&self) -> DualTime {
self.last_update_end.clone()
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
}
impl RtrCache {
pub fn get_deltas_since(
&self,
client_session: u16,
client_serial: u32,
) -> SerialResult {
if client_session != self.session_id {
return SerialResult::ResetRequired;
}
if client_serial == self.serial {
return SerialResult::UpToDate;
}
if self.deltas.is_empty() {
return SerialResult::ResetRequired;
}
let oldest_serial = self.deltas.front().unwrap().serial;
let newest_serial = self.deltas.back().unwrap().serial;
let min_supported = oldest_serial.wrapping_sub(1);
if client_serial < min_supported {
return SerialResult::ResetRequired;
}
if client_serial > self.serial {
return SerialResult::ResetRequired;
}
let mut result = Vec::new();
for delta in &self.deltas {
if delta.serial > client_serial {
result.push(delta.clone());
}
}
if let Some(first) = result.first() {
if first.serial != client_serial.wrapping_add(1) {
return SerialResult::ResetRequired;
}
} else {
return SerialResult::UpToDate;
}
let _ = newest_serial;
SerialResult::Deltas(result)
}
}
pub enum SerialResult {
/// Client is up to date.
UpToDate,
/// Return applicable deltas.
Deltas(Vec<Arc<Delta>>),
/// Delta window cannot cover; reset required.
ResetRequired,
}
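The selection rule in `get_deltas_since` can be sketched on bare serial numbers: a client is only served when the retained window contains a contiguous run starting at `client_serial + 1`. This simplified stand-in elides the session-id and window-edge checks (`Outcome` and `deltas_since` are illustrative names, not crate API):

```rust
// Outcome mirror of SerialResult, over bare serial numbers.
#[derive(Debug, PartialEq)]
enum Outcome {
    UpToDate,
    Deltas(Vec<u32>),
    ResetRequired,
}

// window: retained delta serials in ascending order;
// current: the cache's serial; client: the router's serial.
fn deltas_since(window: &[u32], current: u32, client: u32) -> Outcome {
    if client == current {
        return Outcome::UpToDate;
    }
    let picked: Vec<u32> = window.iter().copied().filter(|s| *s > client).collect();
    match picked.first() {
        // The run must start right after the client's serial.
        Some(first) if *first == client.wrapping_add(1) => Outcome::Deltas(picked),
        Some(_) => Outcome::ResetRequired,
        None => Outcome::UpToDate,
    }
}

fn main() {
    let window = [5, 6, 7];
    assert_eq!(deltas_since(&window, 7, 7), Outcome::UpToDate);
    assert_eq!(deltas_since(&window, 7, 5), Outcome::Deltas(vec![6, 7]));
    // Serial 3 fell out of the window: the client must reset.
    assert_eq!(deltas_since(&window, 7, 3), Outcome::ResetRequired);
}
```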
//------------ RTR ordering -------------------------------------------------
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum PayloadPduType {
Ipv4Prefix = 4,
Ipv6Prefix = 6,
RouterKey = 9,
Aspa = 11,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum RouteOriginKey {
V4 {
addr: u32,
plen: u8,
mlen: u8,
asn: u32,
},
V6 {
addr: u128,
plen: u8,
mlen: u8,
asn: u32,
},
}
fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) {
payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce));
}
fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering {
let type_a = payload_pdu_type(a);
let type_b = payload_pdu_type(b);
match type_a.cmp(&type_b) {
Ordering::Equal => {}
other => return other,
}
match (a, b) {
(Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => {
compare_route_origin_for_rtr(a, b, announce)
}
(Payload::RouterKey(a), Payload::RouterKey(b)) => {
compare_router_key_for_rtr(a, b)
}
(Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b),
_ => Ordering::Equal,
}
}
fn payload_pdu_type(payload: &Payload) -> PayloadPduType {
match payload {
Payload::RouteOrigin(ro) => {
if route_origin_is_ipv4(ro) {
PayloadPduType::Ipv4Prefix
} else {
PayloadPduType::Ipv6Prefix
}
}
Payload::RouterKey(_) => PayloadPduType::RouterKey,
Payload::Aspa(_) => PayloadPduType::Aspa,
}
}
fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool {
ro.prefix().address.is_ipv4()
}
fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey {
let prefix = ro.prefix();
let plen = prefix.prefix_length;
let mlen = ro.max_length();
let asn = ro.asn().into_u32();
match prefix.address {
IPAddress::V4(addr) => {
RouteOriginKey::V4 {
addr: u32::from(addr),
plen,
mlen,
asn,
}
}
IPAddress::V6(addr) => {
RouteOriginKey::V6 {
addr: u128::from(addr),
plen,
mlen,
asn,
}
}
}
}
fn compare_route_origin_for_rtr(
a: &RouteOrigin,
b: &RouteOrigin,
announce: bool,
) -> Ordering {
match (route_origin_key(a), route_origin_key(b)) {
(
RouteOriginKey::V4 {
addr: addr_a,
plen: plen_a,
mlen: mlen_a,
asn: asn_a,
},
RouteOriginKey::V4 {
addr: addr_b,
plen: plen_b,
mlen: mlen_b,
asn: asn_b,
},
) => {
if announce {
addr_b.cmp(&addr_a)
.then_with(|| mlen_b.cmp(&mlen_a))
.then_with(|| plen_b.cmp(&plen_a))
.then_with(|| asn_b.cmp(&asn_a))
} else {
addr_a.cmp(&addr_b)
.then_with(|| mlen_a.cmp(&mlen_b))
.then_with(|| plen_a.cmp(&plen_b))
.then_with(|| asn_a.cmp(&asn_b))
}
}
(
RouteOriginKey::V6 {
addr: addr_a,
plen: plen_a,
mlen: mlen_a,
asn: asn_a,
},
RouteOriginKey::V6 {
addr: addr_b,
plen: plen_b,
mlen: mlen_b,
asn: asn_b,
},
) => {
if announce {
addr_b.cmp(&addr_a)
.then_with(|| mlen_b.cmp(&mlen_a))
.then_with(|| plen_b.cmp(&plen_a))
.then_with(|| asn_b.cmp(&asn_a))
} else {
addr_a.cmp(&addr_b)
.then_with(|| mlen_a.cmp(&mlen_b))
.then_with(|| plen_a.cmp(&plen_b))
.then_with(|| asn_a.cmp(&asn_b))
}
}
_ => Ordering::Equal,
}
}
fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> Ordering {
a.ski()
.cmp(&b.ski())
.then_with(|| a.spki().len().cmp(&b.spki().len()))
.then_with(|| a.spki().cmp(b.spki()))
.then_with(|| a.asn().into_u32().cmp(&b.asn().into_u32()))
}
fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering {
a.customer_asn()
.into_u32()
.cmp(&b.customer_asn().into_u32())
}

src/rtr/cache/core.rs (new file, 632 lines)
View File

@ -0,0 +1,632 @@
use std::collections::{BTreeMap, VecDeque};
use std::cmp::Ordering;
use std::sync::Arc;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use crate::rtr::payload::{Payload, Timing};
use super::model::{Delta, DualTime, Snapshot};
use super::ordering::{change_key, ChangeKey};
const SERIAL_HALF_RANGE: u32 = 1 << 31;
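`SERIAL_HALF_RANGE` is the pivot for RFC 1982 serial-number arithmetic, which keeps comparisons correct when the 32-bit serial wraps past `u32::MAX`. A standalone sketch of the comparison (`serial_lt` is illustrative; it is not a function defined in this module):

```rust
const SERIAL_HALF_RANGE: u32 = 1 << 31;

// RFC 1982: a precedes b when the wrapped distance from a to b
// is non-zero and less than half the serial space.
fn serial_lt(a: u32, b: u32) -> bool {
    let distance = b.wrapping_sub(a);
    distance != 0 && distance < SERIAL_HALF_RANGE
}

fn main() {
    assert!(serial_lt(1, 2));
    // Wrap-around: u32::MAX is immediately before 0.
    assert!(serial_lt(u32::MAX, 0));
    assert!(!serial_lt(0, u32::MAX));
    // A serial never precedes itself.
    assert!(!serial_lt(7, 7));
}
```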
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq)]
pub enum CacheAvailability {
Ready,
NoDataAvailable,
}
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct SessionIds {
ids: [u16; 3],
}
impl SessionIds {
pub fn from_array(ids: [u16; 3]) -> Self {
Self { ids }
}
pub fn random_distinct() -> Self {
let mut ids = [0u16; 3];
for idx in 0..ids.len() {
loop {
let candidate: u16 = rand::random();
if ids[..idx].iter().all(|existing| *existing != candidate) {
ids[idx] = candidate;
break;
}
}
}
Self { ids }
}
pub fn get(&self, version: u8) -> u16 {
self.ids[version_index(version)]
}
}
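`random_distinct` keeps drawing candidates until all three per-version session IDs differ. The rejection loop can be exercised deterministically by swapping `rand::random` for a fixed sequence (`distinct_ids` and the iterator source are illustrative; the mapping of the three slots to RTR protocol versions is an assumption based on `get`):

```rust
// Draw from `source` until three mutually distinct ids are filled,
// mirroring SessionIds::random_distinct's rejection loop.
fn distinct_ids(source: &mut impl Iterator<Item = u16>) -> [u16; 3] {
    let mut ids = [0u16; 3];
    for idx in 0..ids.len() {
        loop {
            let candidate = source.next().expect("source exhausted");
            // Only compare against slots already filled.
            if ids[..idx].iter().all(|existing| *existing != candidate) {
                ids[idx] = candidate;
                break;
            }
        }
    }
    ids
}

fn main() {
    // Duplicate draws (7 again, then 9 and 7 again) are rejected.
    let mut source = [7u16, 7, 9, 9, 7, 3].into_iter();
    assert_eq!(distinct_ids(&mut source), [7, 9, 3]);
}
```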
#[derive(Debug)]
pub struct RtrCache {
availability: CacheAvailability,
session_ids: SessionIds,
serial: u32,
snapshot: Snapshot,
deltas: VecDeque<Arc<Delta>>,
max_delta: u8,
timing: Timing,
last_update_begin: DualTime,
last_update_end: DualTime,
created_at: DualTime,
}
impl Default for RtrCache {
fn default() -> Self {
let now = DualTime::now();
Self {
availability: CacheAvailability::Ready,
session_ids: SessionIds::random_distinct(),
serial: 0,
snapshot: Snapshot::empty(),
deltas: VecDeque::with_capacity(100),
max_delta: 100,
timing: Timing::default(),
last_update_begin: now.clone(),
last_update_end: now.clone(),
created_at: now,
}
}
}
pub struct RtrCacheBuilder {
availability: Option<CacheAvailability>,
session_ids: Option<SessionIds>,
max_delta: Option<u8>,
timing: Option<Timing>,
serial: Option<u32>,
snapshot: Option<Snapshot>,
deltas: Option<VecDeque<Arc<Delta>>>,
created_at: Option<DualTime>,
}
impl RtrCacheBuilder {
pub fn new() -> Self {
Self {
availability: None,
session_ids: None,
max_delta: None,
timing: None,
serial: None,
snapshot: None,
deltas: None,
created_at: None,
}
}
pub fn session_ids(mut self, v: SessionIds) -> Self {
self.session_ids = Some(v);
self
}
pub fn availability(mut self, v: CacheAvailability) -> Self {
self.availability = Some(v);
self
}
pub fn max_delta(mut self, v: u8) -> Self {
self.max_delta = Some(v);
self
}
pub fn timing(mut self, v: Timing) -> Self {
self.timing = Some(v);
self
}
pub fn serial(mut self, v: u32) -> Self {
self.serial = Some(v);
self
}
pub fn snapshot(mut self, v: Snapshot) -> Self {
self.snapshot = Some(v);
self
}
pub fn deltas(mut self, v: VecDeque<Arc<Delta>>) -> Self {
self.deltas = Some(v);
self
}
pub fn created_at(mut self, v: DualTime) -> Self {
self.created_at = Some(v);
self
}
pub fn build(self) -> RtrCache {
let now = DualTime::now();
let max_delta = self.max_delta.unwrap_or(100);
let timing = self.timing.unwrap_or_default();
let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty);
let deltas = self
.deltas
.unwrap_or_else(|| VecDeque::with_capacity(max_delta.into()));
let serial = self.serial.unwrap_or(0);
let created_at = self.created_at.unwrap_or_else(|| now.clone());
let availability = self.availability.unwrap_or(CacheAvailability::Ready);
let session_ids = self
.session_ids
.unwrap_or_else(SessionIds::random_distinct);
RtrCache {
availability,
session_ids,
serial,
snapshot,
deltas,
max_delta,
timing,
last_update_begin: now.clone(),
last_update_end: now,
created_at,
}
}
}
impl RtrCache {
fn set_unavailable(&mut self) {
warn!(
"RTR cache entering NoDataAvailable: old_serial={}, snapshot_empty={}, delta_count={}",
self.serial,
self.snapshot.is_empty(),
self.deltas.len()
);
self.availability = CacheAvailability::NoDataAvailable;
self.snapshot = Snapshot::empty();
self.deltas.clear();
}
fn reinitialize_from_snapshot(&mut self, snapshot: Snapshot) -> AppliedUpdate {
let old_serial = self.serial;
let old_session_ids = self.session_ids.clone();
self.availability = CacheAvailability::Ready;
self.session_ids = SessionIds::random_distinct();
self.serial = 1;
self.snapshot = snapshot.clone();
self.deltas.clear();
self.last_update_end = DualTime::now();
info!(
"RTR cache reinitialized from usable snapshot: old_serial={}, new_serial={}, old_session_ids={:?}, new_session_ids={:?}, payloads(route_origins={}, router_keys={}, aspas={})",
old_serial,
self.serial,
old_session_ids,
self.session_ids,
snapshot.origins().len(),
snapshot.router_keys().len(),
snapshot.aspas().len()
);
AppliedUpdate {
availability: self.availability,
snapshot,
serial: self.serial,
session_ids: self.session_ids.clone(),
delta: None,
delta_window: None,
clear_delta_window: true,
}
}
fn next_serial(&mut self) -> u32 {
let old = self.serial;
self.serial = self.serial.wrapping_add(1);
debug!(
"RTR cache advanced serial: old_serial={}, new_serial={}",
old,
self.serial
);
self.serial
}
fn push_delta(&mut self, delta: Arc<Delta>) {
        let dropped_serial = if self.deltas.len() >= self.max_delta as usize {
            // Window is full: evict the oldest delta and remember its serial.
            self.deltas.pop_front().map(|oldest| oldest.serial())
        } else {
            None
        };
debug!(
"RTR cache pushing delta into window: delta_serial={}, announced={}, withdrawn={}, dropped_oldest_serial={:?}, window_size_before={}, max_delta={}",
delta.serial(),
delta.announced().len(),
delta.withdrawn().len(),
dropped_serial,
self.deltas.len(),
self.max_delta
);
self.deltas.push_back(delta);
}
fn replace_snapshot(&mut self, snapshot: Snapshot) {
self.snapshot = snapshot;
}
fn delta_window(&self) -> Option<(u32, u32)> {
let min = self.deltas.front().map(|d| d.serial());
let max = self.deltas.back().map(|d| d.serial());
match (min, max) {
(Some(min), Some(max)) => Some((min, max)),
_ => None,
}
}
pub(super) fn apply_update(&mut self, new_payloads: Vec<Payload>) -> Result<Option<AppliedUpdate>> {
self.last_update_begin = DualTime::now();
info!(
"RTR cache applying update: availability={:?}, current_serial={}, incoming_payloads={}",
self.availability,
self.serial,
new_payloads.len()
);
let new_snapshot = Snapshot::from_payloads(new_payloads);
debug!(
"RTR cache built new snapshot from update: route_origins={}, router_keys={}, aspas={}, snapshot_empty={}",
new_snapshot.origins().len(),
new_snapshot.router_keys().len(),
new_snapshot.aspas().len(),
new_snapshot.is_empty()
);
if new_snapshot.is_empty() {
let changed = self.availability != CacheAvailability::NoDataAvailable
|| !self.snapshot.is_empty()
|| !self.deltas.is_empty();
self.set_unavailable();
self.last_update_end = DualTime::now();
if !changed {
debug!("RTR cache update produced empty snapshot but cache was already unavailable; no state change");
return Ok(None);
}
info!(
"RTR cache update cleared usable data and marked cache unavailable: serial={}, session_ids={:?}",
self.serial,
self.session_ids
);
return Ok(Some(AppliedUpdate {
availability: self.availability,
snapshot: Snapshot::empty(),
serial: self.serial,
session_ids: self.session_ids.clone(),
delta: None,
delta_window: None,
clear_delta_window: true,
}));
}
if self.availability == CacheAvailability::NoDataAvailable {
info!("RTR cache recovered from NoDataAvailable with non-empty snapshot");
return Ok(Some(self.reinitialize_from_snapshot(new_snapshot)));
}
if self.snapshot.same_content(&new_snapshot) {
self.last_update_end = DualTime::now();
debug!(
"RTR cache update detected identical snapshot content: serial={}, session_ids={:?}",
self.serial,
self.session_ids
);
return Ok(None);
}
let (announced, withdrawn) = self.snapshot.diff(&new_snapshot);
debug!(
"RTR cache diff computed: announced={}, withdrawn={}, current_serial={}",
announced.len(),
withdrawn.len(),
self.serial
);
if announced.is_empty() && withdrawn.is_empty() {
self.last_update_end = DualTime::now();
debug!("RTR cache diff was empty after normalization; no update applied");
return Ok(None);
}
let new_serial = self.next_serial();
let delta = Arc::new(Delta::new(new_serial, announced, withdrawn));
if delta.is_empty() {
self.last_update_end = DualTime::now();
debug!(
"RTR cache delta collapsed to empty after dedup/order normalization: serial={}",
new_serial
);
return Ok(None);
}
self.push_delta(delta.clone());
self.replace_snapshot(new_snapshot.clone());
self.last_update_end = DualTime::now();
let delta_window = self.delta_window();
info!(
"RTR cache applied update: serial={}, announced={}, withdrawn={}, delta_window={:?}, snapshot(route_origins={}, router_keys={}, aspas={})",
new_serial,
delta.announced().len(),
delta.withdrawn().len(),
delta_window,
new_snapshot.origins().len(),
new_snapshot.router_keys().len(),
new_snapshot.aspas().len()
);
Ok(Some(AppliedUpdate {
availability: self.availability,
snapshot: new_snapshot,
serial: new_serial,
session_ids: self.session_ids.clone(),
delta: Some(delta),
delta_window,
clear_delta_window: false,
}))
}
pub fn is_data_available(&self) -> bool {
self.availability == CacheAvailability::Ready
}
pub fn availability(&self) -> CacheAvailability {
self.availability
}
pub fn session_id_for_version(&self, version: u8) -> u16 {
self.session_ids.get(version)
}
pub fn session_ids(&self) -> SessionIds {
self.session_ids.clone()
}
pub fn snapshot(&self) -> Snapshot {
self.snapshot.clone()
}
pub fn serial(&self) -> u32 {
self.serial
}
pub fn timing(&self) -> Timing {
self.timing
}
pub fn current_snapshot_with_session_ids(&self) -> (&Snapshot, u32, SessionIds) {
(&self.snapshot, self.serial, self.session_ids.clone())
}
pub fn last_update_begin(&self) -> DualTime {
self.last_update_begin.clone()
}
pub fn last_update_end(&self) -> DualTime {
self.last_update_end.clone()
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
pub fn get_deltas_since(&self, client_serial: u32) -> SerialResult {
if client_serial == self.serial {
debug!(
"RTR cache delta query is already up to date: client_serial={}, cache_serial={}",
client_serial,
self.serial
);
return SerialResult::UpToDate;
}
if matches!(
serial_cmp(client_serial, self.serial),
Some(Ordering::Greater) | None
) {
warn!(
"RTR cache delta query requires reset due to invalid/newer client serial: client_serial={}, cache_serial={}",
client_serial,
self.serial
);
return SerialResult::ResetRequired;
}
let deltas = match self.collect_deltas_since(client_serial) {
Some(deltas) => deltas,
None => {
warn!(
"RTR cache delta query requires reset because requested serial is outside delta window: client_serial={}, cache_serial={}, delta_window={:?}",
client_serial,
self.serial,
self.delta_window()
);
return SerialResult::ResetRequired;
}
};
if deltas.is_empty() {
debug!(
"RTR cache delta query resolved to no deltas: client_serial={}, cache_serial={}",
client_serial,
self.serial
);
return SerialResult::UpToDate;
}
let merged = self.merge_deltas_minimally(&deltas);
if merged.is_empty() {
debug!(
"RTR cache merged delta query to empty result: client_serial={}, cache_serial={}, source_deltas={}",
client_serial,
self.serial,
deltas.len()
);
SerialResult::UpToDate
} else {
info!(
"RTR cache serving delta query: client_serial={}, cache_serial={}, source_deltas={}, merged_announced={}, merged_withdrawn={}",
client_serial,
self.serial,
deltas.len(),
merged.announced().len(),
merged.withdrawn().len()
);
SerialResult::Delta(merged)
}
}
fn collect_deltas_since(&self, client_serial: u32) -> Option<Vec<Arc<Delta>>> {
if self.deltas.is_empty() {
return None;
}
let oldest_serial = self.deltas.front().unwrap().serial();
let min_supported = oldest_serial.wrapping_sub(1);
if matches!(
serial_cmp(client_serial, min_supported),
Some(Ordering::Less) | None
) {
return None;
}
let mut result = Vec::new();
for delta in &self.deltas {
if serial_gt(delta.serial(), client_serial) {
result.push(delta.clone());
}
}
if let Some(first) = result.first() {
if first.serial() != client_serial.wrapping_add(1) {
return None;
}
}
Some(result)
}
fn merge_deltas_minimally(&self, deltas: &[Arc<Delta>]) -> Delta {
let mut states = BTreeMap::<ChangeKey, LogicalState>::new();
for delta in deltas {
for payload in delta.withdrawn() {
let key = change_key(payload);
let state = states.entry(key).or_insert_with(LogicalState::new);
if state.before.is_none() && state.after.is_none() {
state.before = Some(payload.clone());
}
state.after = None;
}
for payload in delta.announced() {
let key = change_key(payload);
let state = states.entry(key).or_insert_with(LogicalState::new);
state.after = Some(payload.clone());
}
}
let mut announced = Vec::new();
let mut withdrawn = Vec::new();
for (_key, state) in states {
match (state.before, state.after) {
(None, None) => {}
(None, Some(new_payload)) => {
announced.push(new_payload);
}
(Some(old_payload), None) => {
withdrawn.push(old_payload);
}
(Some(old_payload), Some(new_payload)) => {
if old_payload != new_payload {
if matches!(old_payload, Payload::Aspa(_))
&& matches!(new_payload, Payload::Aspa(_))
{
announced.push(new_payload);
} else {
withdrawn.push(old_payload);
announced.push(new_payload);
}
}
}
}
}
Delta::new(self.serial, announced, withdrawn)
}
}
#[derive(Debug, Clone, Default)]
struct LogicalState {
before: Option<Payload>,
after: Option<Payload>,
}
impl LogicalState {
fn new() -> Self {
Self {
before: None,
after: None,
}
}
}
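The collapsing done by `merge_deltas_minimally` can be exercised in isolation. A self-contained sketch of the same before/after logic, using `u32` keys and `i32` values as stand-ins for `ChangeKey` and `Payload` (names here are illustrative, not the crate's API):

```rust
use std::collections::BTreeMap;

// Stand-in for LogicalState: the first withdrawal seen for a key records the
// pre-existing value; any announcement overwrites the "after" side.
#[derive(Default)]
struct State {
    before: Option<i32>,
    after: Option<i32>,
}

/// Collapse a sequence of (withdrawn, announced) deltas into one minimal
/// delta, keeping only the net transition per key.
/// Returns (announced, withdrawn) values.
fn merge_minimally(deltas: &[(Vec<(u32, i32)>, Vec<(u32, i32)>)]) -> (Vec<i32>, Vec<i32>) {
    let mut states = BTreeMap::<u32, State>::new();
    for (withdrawn, announced) in deltas {
        for &(key, val) in withdrawn {
            let st = states.entry(key).or_default();
            if st.before.is_none() && st.after.is_none() {
                st.before = Some(val);
            }
            st.after = None;
        }
        for &(key, val) in announced {
            states.entry(key).or_default().after = Some(val);
        }
    }
    let (mut announced, mut withdrawn) = (Vec::new(), Vec::new());
    for (_key, st) in states {
        match (st.before, st.after) {
            (None, Some(new)) => announced.push(new),
            (Some(old), None) => withdrawn.push(old),
            (Some(old), Some(new)) if old != new => {
                withdrawn.push(old);
                announced.push(new);
            }
            _ => {} // unchanged, or announced-then-withdrawn: cancels out
        }
    }
    (announced, withdrawn)
}
```

A value announced in one delta and withdrawn in a later one disappears entirely from the merged result, which is the point of merging minimally before sending to a router.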
pub enum SerialResult {
UpToDate,
Delta(Delta),
ResetRequired,
}
pub(super) struct AppliedUpdate {
pub(super) availability: CacheAvailability,
pub(super) snapshot: Snapshot,
pub(super) serial: u32,
pub(super) session_ids: SessionIds,
pub(super) delta: Option<Arc<Delta>>,
pub(super) delta_window: Option<(u32, u32)>,
pub(super) clear_delta_window: bool,
}
fn serial_cmp(a: u32, b: u32) -> Option<Ordering> {
if a == b {
return Some(Ordering::Equal);
}
let diff = a.wrapping_sub(b);
if diff == SERIAL_HALF_RANGE {
None
} else if diff < SERIAL_HALF_RANGE {
Some(Ordering::Greater)
} else {
Some(Ordering::Less)
}
}
fn serial_gt(a: u32, b: u32) -> bool {
matches!(serial_cmp(a, b), Some(Ordering::Greater))
}
fn version_index(version: u8) -> usize {
match version {
0..=2 => version as usize,
_ => panic!("unsupported RTR protocol version: {}", version),
}
}
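The two helpers above implement RFC 1982 serial-number arithmetic. A standalone sketch, assuming `SERIAL_HALF_RANGE` is `0x8000_0000` (half the `u32` range, as the comparison logic implies):

```rust
use std::cmp::Ordering;

// Half of the u32 range; two serials exactly this far apart are unordered.
const SERIAL_HALF_RANGE: u32 = 0x8000_0000;

fn serial_cmp(a: u32, b: u32) -> Option<Ordering> {
    if a == b {
        return Some(Ordering::Equal);
    }
    let diff = a.wrapping_sub(b);
    if diff == SERIAL_HALF_RANGE {
        None // undefined per RFC 1982
    } else if diff < SERIAL_HALF_RANGE {
        Some(Ordering::Greater)
    } else {
        Some(Ordering::Less)
    }
}

fn serial_gt(a: u32, b: u32) -> bool {
    matches!(serial_cmp(a, b), Some(Ordering::Greater))
}
```

Note that serial 0 compares greater than `u32::MAX`: the wrap-around is what lets a long-running cache keep incrementing serials indefinitely.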

src/rtr/cache/mod.rs (new file)
@@ -0,0 +1,14 @@
mod core;
mod model;
mod ordering;
mod store;
pub use core::{CacheAvailability, RtrCache, RtrCacheBuilder, SerialResult, SessionIds};
pub use model::{Delta, DualTime, Snapshot};
pub use ordering::{
OrderingViolation, validate_payload_updates_for_rtr, validate_payloads_for_rtr,
};
use std::sync::{Arc, RwLock};
pub type SharedRtrCache = Arc<RwLock<RtrCache>>;
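`SharedRtrCache` is the standard `Arc<RwLock<_>>` sharing pattern: the updater takes the write lock, RTR sessions take read locks. A minimal sketch with a hypothetical stand-in struct (the real `RtrCache` lives in `core.rs`):

```rust
use std::sync::{Arc, RwLock};

// Illustrative stand-in for RtrCache; field names are hypothetical.
#[derive(Default)]
struct Cache {
    serial: u32,
}

type SharedCache = Arc<RwLock<Cache>>;

// Updater path: take the write lock, mutate, release on drop.
fn bump(cache: &SharedCache) -> u32 {
    let mut guard = cache.write().expect("lock poisoned");
    guard.serial = guard.serial.wrapping_add(1);
    guard.serial
}

// Session path: many readers may hold the read lock concurrently.
fn current(cache: &SharedCache) -> u32 {
    cache.read().expect("lock poisoned").serial
}
```

Cloning the `Arc` hands each connection task a cheap handle to the same cache.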

src/rtr/cache/model.rs (new file)
@@ -0,0 +1,392 @@
use std::collections::{BTreeMap, BTreeSet};
use std::time::{Duration, Instant};
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey};
use super::ordering::{compare_payload_update_for_rtr, sort_payloads_for_rtr};
#[derive(Debug, Clone)]
pub struct DualTime {
instant: Instant,
utc: DateTime<Utc>,
}
impl DualTime {
pub fn now() -> Self {
Self {
instant: Instant::now(),
utc: Utc::now(),
}
}
pub fn utc(&self) -> DateTime<Utc> {
self.utc
}
pub fn elapsed(&self) -> Duration {
self.instant.elapsed()
}
pub fn is_expired(&self, duration: Duration) -> bool {
self.elapsed() >= duration
}
pub fn reset(&mut self) {
self.instant = Instant::now();
self.utc = Utc::now();
}
}
impl Serialize for DualTime {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.utc.timestamp_millis().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for DualTime {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
        let millis = i64::deserialize(deserializer)?;
        // Prefer the non-deprecated chrono (0.4.35+) constructor over
        // NaiveDateTime::from_timestamp_millis + DateTime::from_utc.
        let utc = DateTime::<Utc>::from_timestamp_millis(millis)
            .ok_or_else(|| serde::de::Error::custom("invalid timestamp"))?;
        Ok(Self {
            instant: Instant::now(),
            utc,
        })
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
origins: BTreeSet<RouteOrigin>,
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
created_at: DualTime,
origins_hash: [u8; 32],
router_keys_hash: [u8; 32],
aspas_hash: [u8; 32],
snapshot_hash: [u8; 32],
}
impl Snapshot {
pub fn new(
origins: BTreeSet<RouteOrigin>,
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
) -> Self {
let mut snapshot = Snapshot {
origins,
router_keys,
aspas: normalize_aspas(aspas),
created_at: DualTime::now(),
origins_hash: [0u8; 32],
router_keys_hash: [0u8; 32],
aspas_hash: [0u8; 32],
snapshot_hash: [0u8; 32],
};
snapshot.recompute_hashes();
snapshot
}
pub fn empty() -> Self {
Self::new(BTreeSet::new(), BTreeSet::new(), BTreeSet::new())
}
pub fn from_payloads(payloads: Vec<Payload>) -> Self {
let mut origins = BTreeSet::new();
let mut router_keys = BTreeSet::new();
let mut aspas = Vec::new();
for p in payloads {
match p {
Payload::RouteOrigin(o) => {
origins.insert(o);
}
Payload::RouterKey(k) => {
router_keys.insert(k);
}
Payload::Aspa(a) => {
aspas.push(a);
}
}
}
Snapshot::new(origins, router_keys, normalize_aspas(aspas))
}
pub fn recompute_hashes(&mut self) {
self.origins_hash = self.compute_origins_hash();
self.router_keys_hash = self.compute_router_keys_hash();
self.aspas_hash = self.compute_aspas_hash();
self.snapshot_hash = self.compute_snapshot_hash();
}
fn compute_origins_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.origins.iter())
}
fn compute_router_keys_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.router_keys.iter())
}
fn compute_aspas_hash(&self) -> [u8; 32] {
Self::hash_ordered_iter(self.aspas.iter())
}
fn compute_snapshot_hash(&self) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(b"snapshot:v1");
hasher.update(self.origins_hash);
hasher.update(self.router_keys_hash);
hasher.update(self.aspas_hash);
hasher.finalize().into()
}
fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32]
where
T: Serialize + 'a,
I: IntoIterator<Item = &'a T>,
{
let mut hasher = Sha256::new();
hasher.update(b"set:v1");
for item in iter {
let encoded =
serde_json::to_vec(item).expect("serialize snapshot item for hashing failed");
let len = (encoded.len() as u32).to_be_bytes();
hasher.update(len);
hasher.update(encoded);
}
hasher.finalize().into()
}
pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec<Payload>, Vec<Payload>) {
let mut announced = Vec::new();
let mut withdrawn = Vec::new();
if !self.same_origins(new_snapshot) {
for origin in new_snapshot.origins.difference(&self.origins) {
announced.push(Payload::RouteOrigin(origin.clone()));
}
for origin in self.origins.difference(&new_snapshot.origins) {
withdrawn.push(Payload::RouteOrigin(origin.clone()));
}
}
if !self.same_router_keys(new_snapshot) {
for key in new_snapshot.router_keys.difference(&self.router_keys) {
announced.push(Payload::RouterKey(key.clone()));
}
for key in self.router_keys.difference(&new_snapshot.router_keys) {
withdrawn.push(Payload::RouterKey(key.clone()));
}
}
if !self.same_aspas(new_snapshot) {
diff_aspas(&self.aspas, &new_snapshot.aspas, &mut announced, &mut withdrawn);
}
(announced, withdrawn)
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
pub fn payloads(&self) -> Vec<Payload> {
let mut v = Vec::with_capacity(
self.origins.len() + self.router_keys.len() + self.aspas.len(),
);
v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin));
v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey));
v.extend(self.aspas.iter().cloned().map(Payload::Aspa));
v
}
pub fn payloads_for_rtr(&self) -> Vec<Payload> {
let mut payloads = self.payloads();
sort_payloads_for_rtr(&mut payloads, true);
payloads
}
pub fn origins_hash(&self) -> [u8; 32] {
self.origins_hash
}
pub fn router_keys_hash(&self) -> [u8; 32] {
self.router_keys_hash
}
pub fn aspas_hash(&self) -> [u8; 32] {
self.aspas_hash
}
pub fn snapshot_hash(&self) -> [u8; 32] {
self.snapshot_hash
}
pub fn same_origins(&self, other: &Self) -> bool {
self.origins_hash == other.origins_hash
}
pub fn same_router_keys(&self, other: &Self) -> bool {
self.router_keys_hash == other.router_keys_hash
}
pub fn same_aspas(&self, other: &Self) -> bool {
self.aspas_hash == other.aspas_hash
}
pub fn same_content(&self, other: &Self) -> bool {
self.snapshot_hash == other.snapshot_hash
}
pub fn origins(&self) -> &BTreeSet<RouteOrigin> {
&self.origins
}
pub fn router_keys(&self) -> &BTreeSet<RouterKey> {
&self.router_keys
}
pub fn aspas(&self) -> &BTreeSet<Aspa> {
&self.aspas
}
pub fn is_empty(&self) -> bool {
self.origins.is_empty()
&& self.router_keys.is_empty()
&& self.aspas.is_empty()
}
}
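`hash_ordered_iter` length-prefixes each serialized item before hashing; without the frame, item boundaries would be ambiguous and `["ab", "c"]` would feed the same bytes as `["a", "bc"]`. A sketch of the framing idea, using the standard library's `DefaultHasher` in place of SHA-256:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

// Each item is framed with a 4-byte big-endian length before hashing, so the
// byte stream for ["ab", "c"] differs from the one for ["a", "bc"].
// DefaultHasher stands in for SHA-256; the framing idea is identical.
fn hash_items(items: &[&[u8]]) -> u64 {
    let mut hasher = DefaultHasher::new();
    for item in items {
        hasher.write(&(item.len() as u32).to_be_bytes());
        hasher.write(item);
    }
    hasher.finish()
}
```

Raw `Hasher::write` calls add no framing of their own, so the explicit length prefix is what keeps concatenation ambiguity out of the set hashes.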
#[derive(Debug, Serialize, Deserialize)]
pub struct Delta {
serial: u32,
announced: Vec<Payload>,
withdrawn: Vec<Payload>,
created_at: DualTime,
}
impl Delta {
pub fn new(serial: u32, mut announced: Vec<Payload>, mut withdrawn: Vec<Payload>) -> Self {
dedup_payloads(&mut announced);
dedup_payloads(&mut withdrawn);
sort_payloads_for_rtr(&mut announced, true);
sort_payloads_for_rtr(&mut withdrawn, false);
Delta {
serial,
announced,
withdrawn,
created_at: DualTime::now(),
}
}
pub fn serial(&self) -> u32 {
self.serial
}
pub fn announced(&self) -> &[Payload] {
&self.announced
}
pub fn withdrawn(&self) -> &[Payload] {
&self.withdrawn
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
pub fn is_empty(&self) -> bool {
self.announced.is_empty() && self.withdrawn.is_empty()
}
pub fn payloads_for_rtr(&self) -> Vec<(bool, Payload)> {
let mut updates = Vec::with_capacity(self.announced.len() + self.withdrawn.len());
updates.extend(self.announced.iter().cloned().map(|p| (true, p)));
updates.extend(self.withdrawn.iter().cloned().map(|p| (false, p)));
updates.sort_by(|(a_upd, a_payload), (b_upd, b_payload)| {
compare_payload_update_for_rtr(a_payload, *a_upd, b_payload, *b_upd)
});
updates
}
}
fn dedup_payloads(payloads: &mut Vec<Payload>) {
let mut seen = BTreeSet::new();
payloads.retain(|p| seen.insert(p.clone()));
}
fn normalize_aspas<I>(aspas: I) -> BTreeSet<Aspa>
where
I: IntoIterator<Item = Aspa>,
{
let mut by_customer = BTreeMap::<u32, BTreeSet<_>>::new();
for aspa in aspas {
let providers = by_customer
.entry(aspa.customer_asn().into_u32())
.or_default();
providers.extend(aspa.provider_asns().iter().copied());
}
by_customer
.into_iter()
.map(|(customer_asn, providers)| {
Aspa::new(customer_asn.into(), providers.into_iter().collect())
})
.collect()
}
fn diff_aspas(
current: &BTreeSet<Aspa>,
next: &BTreeSet<Aspa>,
announced: &mut Vec<Payload>,
withdrawn: &mut Vec<Payload>,
) {
let current = current
.iter()
.map(|aspa| (aspa.customer_asn().into_u32(), aspa))
.collect::<BTreeMap<_, _>>();
let next = next
.iter()
.map(|aspa| (aspa.customer_asn().into_u32(), aspa))
.collect::<BTreeMap<_, _>>();
let customers = current
.keys()
.chain(next.keys())
.copied()
.collect::<BTreeSet<_>>();
for customer in customers {
match (current.get(&customer), next.get(&customer)) {
(None, Some(new_aspa)) => announced.push(Payload::Aspa((*new_aspa).clone())),
(Some(old_aspa), None) => withdrawn.push(Payload::Aspa((*old_aspa).clone())),
(Some(old_aspa), Some(new_aspa)) if old_aspa != new_aspa => {
announced.push(Payload::Aspa((*new_aspa).clone()));
}
_ => {}
}
}
}
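`normalize_aspas` merges all ASPA objects sharing a customer ASN into one record whose provider set is the union of the inputs. A self-contained sketch over plain `u32` ASNs:

```rust
use std::collections::{BTreeMap, BTreeSet};

// One output record per customer ASN; providers are deduplicated and sorted
// by accumulating them in a BTreeSet keyed by customer.
fn normalize(aspas: Vec<(u32, Vec<u32>)>) -> Vec<(u32, Vec<u32>)> {
    let mut by_customer = BTreeMap::<u32, BTreeSet<u32>>::new();
    for (customer, providers) in aspas {
        by_customer.entry(customer).or_default().extend(providers);
    }
    by_customer
        .into_iter()
        .map(|(customer, providers)| (customer, providers.into_iter().collect()))
        .collect()
}
```

This is also why `diff_aspas` compares per customer ASN: after normalization there is at most one ASPA record per customer, so a changed provider set shows up as a single re-announcement.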

src/rtr/cache/ordering.rs (new file)
@@ -0,0 +1,311 @@
use std::cmp::Ordering;
use std::fmt;
use crate::data_model::resources::ip_resources::IPAddress;
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski};
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OrderingViolation {
index: usize,
left: String,
right: String,
}
impl OrderingViolation {
fn new(index: usize, left: &Payload, right: &Payload) -> Self {
Self {
index,
left: payload_brief(left),
right: payload_brief(right),
}
}
fn new_update(index: usize, left: (bool, &Payload), right: (bool, &Payload)) -> Self {
Self {
index,
left: payload_update_brief(left.0, left.1),
right: payload_update_brief(right.0, right.1),
}
}
}
impl fmt::Display for OrderingViolation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"RTR payload ordering violation at positions {} and {}: {} should not appear before {}",
self.index,
self.index + 1,
self.left,
self.right
)
}
}
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum RouterKeyKey {
Key { ski: Ski, spki: Vec<u8>, asn: u32 },
}
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum ChangeKey {
RouteOrigin(RouteOriginKey),
RouterKey(RouterKeyKey),
AspaCustomer(u32),
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum PayloadPduType {
Ipv4Prefix = 4,
Ipv6Prefix = 6,
RouterKey = 9,
Aspa = 11,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum RouteOriginKey {
V4 { addr: u32, plen: u8, mlen: u8, asn: u32 },
V6 { addr: u128, plen: u8, mlen: u8, asn: u32 },
}
pub(crate) fn change_key(payload: &Payload) -> ChangeKey {
match payload {
Payload::RouteOrigin(ro) => ChangeKey::RouteOrigin(route_origin_key(ro)),
Payload::RouterKey(rk) => ChangeKey::RouterKey(router_key_key(rk)),
Payload::Aspa(aspa) => ChangeKey::AspaCustomer(aspa.customer_asn().into_u32()),
}
}
pub(crate) fn compare_payload_update_for_rtr(
a_payload: &Payload,
a_announce: bool,
b_payload: &Payload,
b_announce: bool,
) -> Ordering {
let type_a = payload_pdu_type(a_payload);
let type_b = payload_pdu_type(b_payload);
match type_a.cmp(&type_b) {
Ordering::Equal => {}
other => return other,
}
match b_announce.cmp(&a_announce) {
Ordering::Equal => {}
other => return other,
}
match (a_payload, b_payload) {
(Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => {
compare_route_origin_for_rtr(a, b, a_announce)
}
(Payload::RouterKey(a), Payload::RouterKey(b)) => compare_router_key_for_rtr(a, b),
(Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b),
_ => Ordering::Equal,
}
}
pub(crate) fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) {
payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce));
}
pub fn validate_payloads_for_rtr(
payloads: &[Payload],
announce: bool,
) -> Result<(), OrderingViolation> {
for (index, pair) in payloads.windows(2).enumerate() {
if compare_payload_for_rtr(&pair[0], &pair[1], announce) == Ordering::Greater {
return Err(OrderingViolation::new(index, &pair[0], &pair[1]));
}
}
Ok(())
}
pub fn validate_payload_updates_for_rtr(
updates: &[(bool, Payload)],
) -> Result<(), OrderingViolation> {
for (index, pair) in updates.windows(2).enumerate() {
if compare_payload_update_for_rtr(&pair[0].1, pair[0].0, &pair[1].1, pair[1].0)
== Ordering::Greater
{
return Err(OrderingViolation::new_update(
index,
(pair[0].0, &pair[0].1),
(pair[1].0, &pair[1].1),
));
}
}
Ok(())
}
fn router_key_key(rk: &RouterKey) -> RouterKeyKey {
RouterKeyKey::Key {
ski: rk.ski(),
spki: rk.spki().as_ref().to_vec(),
asn: rk.asn().into_u32(),
}
}
fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering {
let type_a = payload_pdu_type(a);
let type_b = payload_pdu_type(b);
match type_a.cmp(&type_b) {
Ordering::Equal => {}
other => return other,
}
match (a, b) {
(Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => {
compare_route_origin_for_rtr(a, b, announce)
}
(Payload::RouterKey(a), Payload::RouterKey(b)) => compare_router_key_for_rtr(a, b),
(Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b),
_ => Ordering::Equal,
}
}
fn payload_pdu_type(payload: &Payload) -> PayloadPduType {
match payload {
Payload::RouteOrigin(ro) => {
if route_origin_is_ipv4(ro) {
PayloadPduType::Ipv4Prefix
} else {
PayloadPduType::Ipv6Prefix
}
}
Payload::RouterKey(_) => PayloadPduType::RouterKey,
Payload::Aspa(_) => PayloadPduType::Aspa,
}
}
fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool {
ro.prefix().address.is_ipv4()
}
fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey {
let prefix = ro.prefix();
let plen = prefix.prefix_length;
let mlen = ro.max_length();
let asn = ro.asn().into_u32();
match prefix.address {
IPAddress::V4(addr) => RouteOriginKey::V4 {
addr: u32::from(addr),
plen,
mlen,
asn,
},
IPAddress::V6(addr) => RouteOriginKey::V6 {
addr: u128::from(addr),
plen,
mlen,
asn,
},
}
}
fn compare_route_origin_for_rtr(a: &RouteOrigin, b: &RouteOrigin, announce: bool) -> Ordering {
match (route_origin_key(a), route_origin_key(b)) {
(
RouteOriginKey::V4 {
addr: addr_a,
plen: plen_a,
mlen: mlen_a,
asn: asn_a,
},
RouteOriginKey::V4 {
addr: addr_b,
plen: plen_b,
mlen: mlen_b,
asn: asn_b,
},
) => {
if announce {
addr_b
.cmp(&addr_a)
.then_with(|| mlen_b.cmp(&mlen_a))
.then_with(|| plen_b.cmp(&plen_a))
.then_with(|| asn_b.cmp(&asn_a))
} else {
addr_a
.cmp(&addr_b)
.then_with(|| mlen_a.cmp(&mlen_b))
.then_with(|| plen_a.cmp(&plen_b))
.then_with(|| asn_a.cmp(&asn_b))
}
}
(
RouteOriginKey::V6 {
addr: addr_a,
plen: plen_a,
mlen: mlen_a,
asn: asn_a,
},
RouteOriginKey::V6 {
addr: addr_b,
plen: plen_b,
mlen: mlen_b,
asn: asn_b,
},
) => {
if announce {
addr_b
.cmp(&addr_a)
.then_with(|| mlen_b.cmp(&mlen_a))
.then_with(|| plen_b.cmp(&plen_a))
.then_with(|| asn_b.cmp(&asn_a))
} else {
addr_a
.cmp(&addr_b)
.then_with(|| mlen_a.cmp(&mlen_b))
.then_with(|| plen_a.cmp(&plen_b))
.then_with(|| asn_a.cmp(&asn_b))
}
}
_ => Ordering::Equal,
}
}
fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> Ordering {
a.ski()
.cmp(&b.ski())
.then_with(|| a.spki().len().cmp(&b.spki().len()))
.then_with(|| a.spki().cmp(b.spki()))
.then_with(|| a.asn().into_u32().cmp(&b.asn().into_u32()))
}
fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering {
a.customer_asn()
.into_u32()
.cmp(&b.customer_asn().into_u32())
}
fn payload_brief(payload: &Payload) -> String {
match payload {
Payload::RouteOrigin(origin) => format!(
"{} prefix {:?}/{} max={} asn={}",
if route_origin_is_ipv4(origin) { "IPv4" } else { "IPv6" },
origin.prefix().address,
origin.prefix().prefix_length,
origin.max_length(),
origin.asn().into_u32()
),
Payload::RouterKey(key) => format!(
"RouterKey ski={:02x?} asn={}",
key.ski().as_ref(),
key.asn().into_u32()
),
Payload::Aspa(aspa) => format!("ASPA customer_asn={}", aspa.customer_asn().into_u32()),
}
}
fn payload_update_brief(announce: bool, payload: &Payload) -> String {
format!(
"{} {}",
if announce { "announce" } else { "withdraw" },
payload_brief(payload)
)
}
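The comparison chain above orders updates by PDU type code first, then puts announcements before withdrawals within a type. A reduced sketch over `(pdu_type, is_announce)` pairs:

```rust
use std::cmp::Ordering;

// PDU type codes mirror PayloadPduType: IPv4=4, IPv6=6, RouterKey=9, ASPA=11.
// Within one type, announcements (true) sort before withdrawals (false),
// hence the reversed bool comparison b.1.cmp(&a.1).
fn cmp_update(a: &(u8, bool), b: &(u8, bool)) -> Ordering {
    a.0.cmp(&b.0).then_with(|| b.1.cmp(&a.1))
}
```

Sorting a mixed batch with this comparator yields the transmission order a router expects: all IPv4 announcements, then IPv4 withdrawals, then IPv6, and so on.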

src/rtr/cache/store.rs (new file)
@@ -0,0 +1,195 @@
use std::collections::VecDeque;
use std::sync::Arc;
use anyhow::Result;
use crate::rtr::payload::{Payload, Timing};
use crate::rtr::store::RtrStore;
use super::core::{AppliedUpdate, CacheAvailability, RtrCache, RtrCacheBuilder, SessionIds};
use super::model::Snapshot;
impl RtrCache {
pub fn init(
self,
store: &RtrStore,
max_delta: u8,
timing: Timing,
file_loader: impl Fn() -> Result<Vec<Payload>>,
) -> Result<Self> {
if let Some(cache) = try_restore_from_store(store, max_delta, timing)? {
tracing::info!(
"RTR cache restored from store: availability={:?}, session_ids={:?}, serial={}, snapshot(route_origins={}, router_keys={}, aspas={})",
cache.availability(),
cache.session_ids(),
cache.serial(),
cache.snapshot().origins().len(),
cache.snapshot().router_keys().len(),
cache.snapshot().aspas().len()
);
return Ok(cache);
}
        tracing::warn!("RTR cache store unavailable or invalid, falling back to file loader");
let payloads = file_loader()?;
let session_ids = SessionIds::random_distinct();
let snapshot = Snapshot::from_payloads(payloads);
let availability = if snapshot.is_empty() {
CacheAvailability::NoDataAvailable
} else {
CacheAvailability::Ready
};
let serial = if snapshot.is_empty() { 0 } else { 1 };
if snapshot.is_empty() {
tracing::warn!(
"RTR cache initialized without usable data: session_ids={:?}, serial={}",
session_ids,
serial
);
} else {
tracing::info!(
"RTR cache initialized from file loader: session_ids={:?}, serial={}",
session_ids,
serial
);
}
let snapshot_for_store = snapshot.clone();
let session_ids_for_store = session_ids.clone();
tokio::spawn({
let store = store.clone();
async move {
if let Err(e) = store.save_cache_state(
availability,
&snapshot_for_store,
&session_ids_for_store,
serial,
None,
None,
true,
) {
tracing::error!("persist cache state failed: {:?}", e);
}
}
});
Ok(RtrCacheBuilder::new()
.availability(availability)
.session_ids(session_ids)
.max_delta(max_delta)
.timing(timing)
.serial(serial)
.snapshot(snapshot)
.build())
}
pub fn update(&mut self, new_payloads: Vec<Payload>, store: &RtrStore) -> Result<()> {
if let Some(update) = self.apply_update(new_payloads)? {
spawn_store_sync(store, update);
}
Ok(())
}
}
fn try_restore_from_store(store: &RtrStore, max_delta: u8, timing: Timing) -> Result<Option<RtrCache>> {
let snapshot = store.get_snapshot()?;
let session_ids = store.get_session_ids()?;
let serial = store.get_serial()?;
let availability = store.get_availability()?;
let (snapshot, session_ids, serial) = match (snapshot, session_ids, serial) {
(Some(snapshot), Some(session_ids), Some(serial)) => (snapshot, session_ids, serial),
_ => {
tracing::warn!("RTR cache store incomplete: snapshot/session_ids/serial missing");
return Ok(None);
}
};
let availability = availability.unwrap_or_else(|| {
tracing::warn!("RTR cache store missing availability metadata, defaulting to Ready");
CacheAvailability::Ready
});
let deltas = if availability == CacheAvailability::NoDataAvailable {
tracing::warn!("RTR cache store restored in no-data-available state");
VecDeque::with_capacity(max_delta.into())
} else {
match store.get_delta_window()? {
Some((min_serial, max_serial)) => {
match store.load_delta_window(min_serial, max_serial) {
Ok(deltas) => deltas.into_iter().map(Arc::new).collect(),
Err(err) => {
tracing::warn!(
                            "RTR cache store delta recovery failed, treating store as unusable: {:?}",
err
);
return Ok(None);
}
}
}
None => {
tracing::info!("RTR cache store has no delta window, restore snapshot only");
VecDeque::with_capacity(max_delta.into())
}
}
};
Ok(Some(
RtrCacheBuilder::new()
.availability(availability)
.session_ids(session_ids)
.max_delta(max_delta)
.timing(timing)
.serial(serial)
.snapshot(snapshot)
.deltas(deltas)
.build(),
))
}
fn spawn_store_sync(store: &RtrStore, update: AppliedUpdate) {
let AppliedUpdate {
availability,
snapshot,
serial,
session_ids,
delta,
delta_window,
clear_delta_window,
} = update;
tokio::spawn({
let store = store.clone();
async move {
tracing::debug!(
"persisting RTR cache state: availability={:?}, serial={}, session_ids={:?}, delta_present={}, delta_window={:?}, clear_delta_window={}, snapshot(route_origins={}, router_keys={}, aspas={})",
availability,
serial,
session_ids,
delta.is_some(),
delta_window,
clear_delta_window,
snapshot.origins().len(),
snapshot.router_keys().len(),
snapshot.aspas().len()
);
if let Err(e) = store.save_cache_state(
availability,
&snapshot,
&session_ids,
serial,
delta.as_deref(),
delta_window,
clear_delta_window,
) {
tracing::error!("persist cache state failed: {:?}", e);
} else {
tracing::debug!("persist RTR cache state completed: serial={}", serial);
}
}
});
}

@@ -1,9 +1,9 @@
 pub mod pdu;
 pub mod cache;
 pub mod payload;
-pub mod store_db;
+pub mod store;
 pub mod session;
 pub mod error_type;
 pub mod state;
 pub mod server;
 pub mod loader;

@@ -1,4 +1,5 @@
 use std::fmt::Debug;
+use std::io;
 use std::time::Duration;
 use serde::{Deserialize, Serialize};
 use crate::data_model::resources::as_resources::Asn;
@@ -16,6 +17,11 @@ enum PayloadPduType {
 #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
 pub struct Ski([u8; 20]);
 
+impl AsRef<[u8]> for Ski {
+    fn as_ref(&self) -> &[u8] {
+        &self.0
+    }
+}
+
 #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
 pub struct RouteOrigin {
@@ -100,10 +106,35 @@ impl Aspa {
     pub fn provider_asns(&self) -> &[Asn] {
         &self.provider_asns
     }
+
+    pub fn validate_announcement(&self) -> Result<(), io::Error> {
+        if self.customer_asn.into_u32() == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ASPA customer ASN must not be AS0",
+            ));
+        }
+        if self.provider_asns.is_empty() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ASPA announcement must contain at least one provider ASN",
+            ));
+        }
+        if self.provider_asns.iter().any(|asn| asn.into_u32() == 0) {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ASPA provider list must not contain AS0",
+            ));
+        }
+        Ok(())
+    }
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
 pub enum Payload {
     /// A route origin.
@@ -131,10 +162,71 @@ pub struct Timing {
 }
 
 impl Timing {
+    pub const MIN_REFRESH: u32 = 1;
+    pub const MAX_REFRESH: u32 = 86_400;
+    pub const MIN_RETRY: u32 = 1;
+    pub const MAX_RETRY: u32 = 7_200;
+    pub const MIN_EXPIRE: u32 = 600;
+    pub const MAX_EXPIRE: u32 = 172_800;
+
     pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self {
         Self { refresh, retry, expire }
     }
+
+    pub fn validate(self) -> Result<(), io::Error> {
+        if !(Self::MIN_REFRESH..=Self::MAX_REFRESH).contains(&self.refresh) {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "refresh interval {} out of range {}..={}",
+                    self.refresh, Self::MIN_REFRESH, Self::MAX_REFRESH
+                ),
+            ));
+        }
+        if !(Self::MIN_RETRY..=Self::MAX_RETRY).contains(&self.retry) {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "retry interval {} out of range {}..={}",
+                    self.retry, Self::MIN_RETRY, Self::MAX_RETRY
+                ),
+            ));
+        }
+        if !(Self::MIN_EXPIRE..=Self::MAX_EXPIRE).contains(&self.expire) {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "expire interval {} out of range {}..={}",
+                    self.expire, Self::MIN_EXPIRE, Self::MAX_EXPIRE
+                ),
+            ));
+        }
+        if self.expire <= self.refresh {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "expire interval {} must be greater than refresh interval {}",
+                    self.expire, self.refresh
+                ),
+            ));
+        }
+        if self.expire <= self.retry {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                format!(
+                    "expire interval {} must be greater than retry interval {}",
+                    self.expire, self.retry
+                ),
+            ));
+        }
+        Ok(())
+    }
+
     pub fn refresh(self) -> Duration {
         Duration::from_secs(u64::from(self.refresh))
     }
@@ -157,4 +249,4 @@ impl Default for Timing {
             expire: 7200,
         }
     }
 }
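`Timing::validate` above enforces the interval bounds from RFC 8210 Section 6: refresh 1..=86400, retry 1..=7200, expire 600..=172800, and expire strictly greater than both refresh and retry. A free-standing sketch of the same checks (a simplified stand-in, not the crate's `Timing` type):

```rust
// Range checks mirroring Timing::validate above; the bounds follow
// RFC 8210 §6 (refresh 1..=86400, retry 1..=7200, expire 600..=172800).
fn validate(refresh: u32, retry: u32, expire: u32) -> Result<(), String> {
    if !(1..=86_400).contains(&refresh) {
        return Err("refresh out of range".into());
    }
    if !(1..=7_200).contains(&retry) {
        return Err("retry out of range".into());
    }
    if !(600..=172_800).contains(&expire) {
        return Err("expire out of range".into());
    }
    if expire <= refresh {
        return Err("expire must exceed refresh".into());
    }
    if expire <= retry {
        return Err("expire must exceed retry".into());
    }
    Ok(())
}

fn main() {
    assert!(validate(3600, 600, 7200).is_ok()); // RFC 8210 default intervals
    assert!(validate(0, 600, 7200).is_err()); // refresh below minimum
    assert!(validate(3600, 600, 3600).is_err()); // expire not greater than refresh
}
```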


@@ -2,9 +2,9 @@ use std::{cmp, mem};
 use std::net::{Ipv4Addr, Ipv6Addr};
 use std::sync::Arc;
 use crate::data_model::resources::as_resources::Asn;
+use crate::rtr::error_type::ErrorCode;
 use crate::rtr::payload::{Ski, Timing};
 use std::io;
-use std::io::Write;
 use tokio::io::{AsyncWrite};
 use anyhow::Result;
@@ -12,7 +12,6 @@ use std::slice;
 use anyhow::bail;
 use serde::Serialize;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
-use tokio::net::TcpStream;
 
 pub const HEADER_LEN: usize = 8;
 pub const MAX_PDU_LEN: u32 = 65535;
@@ -208,10 +207,15 @@ impl Header {
         }
     }
 
-    pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self, io::Error> {
+    pub async fn read_raw<S: AsyncRead + Unpin>(
+        sock: &mut S
+    ) -> Result<[u8; HEADER_LEN], io::Error> {
         let mut buf = [0u8; HEADER_LEN];
         sock.read_exact(&mut buf).await?;
+        Ok(buf)
+    }
+
+    pub fn from_raw(buf: [u8; HEADER_LEN]) -> Result<Self, io::Error> {
         let version = buf[0];
         let pdu = buf[1];
         let session_id = u16::from_be_bytes([buf[2], buf[3]]);
@@ -239,6 +243,10 @@ impl Header {
         })
     }
 
+    pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self, io::Error> {
+        Self::from_raw(Self::read_raw(sock).await?)
+    }
+
     pub fn version(self) -> u8 { self.version }
     pub fn pdu(self) -> u8 { self.pdu }
@@ -256,6 +264,10 @@ impl Header {
         })
     }
 
+    pub fn error_code(self) -> u16 {
+        debug_assert_eq!(self.pdu(), ErrorReport::PDU);
+        self.session_id()
+    }
 }
@@ -281,7 +293,7 @@ impl HeaderWithFlags {
             length: length.to_be(),
         }
     }
 
-    pub async fn read(sock: &mut TcpStream) -> Result<Self> {
+    pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self> {
         let mut buf = [0u8; HEADER_LEN];
         // 1. Read exactly 8 bytes
@@ -311,7 +323,7 @@
             pdu,
             flags,
             zero,
-            length,
+            length: length.to_be(),
         })
     }
@@ -344,7 +356,7 @@ impl SerialNotify {
     }
 
     pub fn serial_number(self) -> u32 {
-        self.serial_number
+        u32::from_be(self.serial_number)
     }
 }
@@ -370,7 +382,7 @@ impl SerialQuery {
     }
 
     pub fn serial_number(self) -> u32 {
-        self.serial_number
+        u32::from_be(self.serial_number)
    }
 }
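The `#[repr(C, packed)]` PDU structs above hold multi-byte fields in network byte order and convert with `to_be`/`from_be` only at the constructor/accessor boundary, which is exactly what the `serial_number()` fixes in this hunk restore. A minimal illustration of the pattern (the struct here is a toy, not one of the PDUs above):

```rust
#[repr(C, packed)]
struct SerialFields {
    serial_number: u32, // held in network byte order, exactly as on the wire
}

impl SerialFields {
    fn new(serial: u32) -> Self {
        // Convert once on construction; the stored value is wire-ready.
        Self { serial_number: serial.to_be() }
    }

    fn serial_number(&self) -> u32 {
        // Convert back on access; copying the Copy field avoids an
        // unaligned reference into the packed struct.
        u32::from_be(self.serial_number)
    }
}

fn main() {
    let pdu = SerialFields::new(0x0102_0304);
    // The accessor undoes the swap regardless of host endianness.
    assert_eq!(pdu.serial_number(), 0x0102_0304);
    // The in-memory bytes stay in big-endian (wire) order on any host.
    let stored = pdu.serial_number;
    assert_eq!(stored.to_ne_bytes(), [1, 2, 3, 4]);
}
```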
@@ -538,12 +550,17 @@ pub enum EndOfData {
 }
 
 impl EndOfData {
-    pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self {
+    pub fn new(
+        version: u8,
+        session_id: u16,
+        serial_number: u32,
+        timing: Timing,
+    ) -> Result<Self, io::Error> {
         if version == 0 {
-            EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number))
+            Ok(EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number)))
         }
         else {
-            EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing))
+            Ok(EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)?))
         }
     }
@@ -588,14 +605,37 @@ pub struct EndOfDataV1 {
 impl EndOfDataV1 {
     pub const PDU: u8 = 7;
 
-    pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self {
-        EndOfDataV1 {
+    pub fn version(&self) -> u8 {
+        self.header.version()
+    }
+
+    pub fn session_id(&self) -> u16 {
+        self.header.session_id()
+    }
+
+    pub fn pdu(&self) -> u8 {
+        self.header.pdu()
+    }
+
+    pub fn size() -> u32 {
+        mem::size_of::<Self>() as u32
+    }
+
+    pub fn new(
+        version: u8,
+        session_id: u16,
+        serial_number: u32,
+        timing: Timing,
+    ) -> Result<Self, io::Error> {
+        timing.validate()?;
+        Ok(EndOfDataV1 {
             header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V1_LEN),
             serial_number: serial_number.to_be(),
             refresh_interval: timing.refresh.to_be(),
             retry_interval: timing.retry.to_be(),
             expire_interval: timing.expire.to_be(),
-        }
+        })
     }
 
     pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) }
@@ -607,8 +647,50 @@ impl EndOfDataV1 {
             expire: u32::from_be(self.expire_interval),
         }
     }
+
+    fn validate(&self) -> Result<(), io::Error> {
+        self.timing().validate()
+    }
+
+    pub async fn read<Sock: AsyncRead + Unpin>(
+        sock: &mut Sock
+    ) -> Result<Self, io::Error> {
+        let mut res = Self::default();
+        sock.read_exact(res.header.as_mut()).await?;
+        if res.header.pdu() != Self::PDU {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "PDU type mismatch when expecting EndOfDataV1",
+            ))
+        }
+        if res.header.length() as usize != mem::size_of::<Self>() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "invalid length for EndOfDataV1",
+            ))
+        }
+        sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
+        res.validate()?;
+        Ok(res)
+    }
+
+    pub async fn read_payload<Sock: AsyncRead + Unpin>(
+        header: Header, sock: &mut Sock
+    ) -> Result<Self, io::Error> {
+        if header.length() as usize != mem::size_of::<Self>() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "invalid length for EndOfDataV1 PDU",
+            ))
+        }
+        let mut res = Self::default();
+        sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
+        res.header = header;
+        res.validate()?;
+        Ok(res)
+    }
 }
 
-concrete!(EndOfDataV1);
+common!(EndOfDataV1);
 
 // Cache Reset
 #[repr(C, packed)]
@@ -640,6 +722,7 @@ pub struct ErrorReport {
 impl ErrorReport {
     /// The PDU type of an error PDU.
     pub const PDU: u8 = 10;
+    const FIXED_PART_LEN: usize = HEADER_LEN + 2 * mem::size_of::<u32>();
 
     /// Creates a new error PDU from components.
     pub fn new(
@@ -650,12 +733,12 @@ impl ErrorReport {
     ) -> Self {
         let pdu = pdu.as_ref();
         let text = text.as_ref();
+        let max_payload_len = MAX_PDU_LEN as usize - Self::FIXED_PART_LEN;
+        let pdu_len = cmp::min(pdu.len(), max_payload_len);
+        let text_room = max_payload_len - pdu_len;
+        let text_len = cmp::min(text.len(), text_room);
 
-        let size =
-            mem::size_of::<Header>()
-            + 2 * mem::size_of::<u32>()
-            + pdu.len() + text.len()
-        ;
+        let size = Self::FIXED_PART_LEN + pdu_len + text_len;
         let header = Header::new(
             version, 10, error_code, u32::try_from(size).unwrap()
         );
@@ -663,37 +746,92 @@ impl ErrorReport {
         let mut octets = Vec::with_capacity(size);
         octets.extend_from_slice(header.as_ref());
         octets.extend_from_slice(
-            u32::try_from(pdu.len()).unwrap().to_be_bytes().as_ref()
+            u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref()
         );
-        octets.extend_from_slice(pdu);
+        octets.extend_from_slice(&pdu[..pdu_len]);
         octets.extend_from_slice(
-            u32::try_from(text.len()).unwrap().to_be_bytes().as_ref()
+            u32::try_from(text_len).unwrap().to_be_bytes().as_ref()
         );
-        octets.extend_from_slice(text);
+        octets.extend_from_slice(&text[..text_len]);
         ErrorReport { octets }
     }
 
+    pub async fn read<Sock: AsyncRead + Unpin>(
+        sock: &mut Sock
+    ) -> Result<Self, io::Error> {
+        let header = Header::read(sock).await?;
+        if header.pdu() != Self::PDU {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "PDU type mismatch when expecting ErrorReport",
+            ));
+        }
+        Self::read_payload(header, sock).await
+    }
+
+    pub async fn read_payload<Sock: AsyncRead + Unpin>(
+        header: Header,
+        sock: &mut Sock,
+    ) -> Result<Self, io::Error> {
+        let total_len = header.pdu_len()?;
+        let Some(payload_len) = total_len.checked_sub(mem::size_of::<Header>()) else {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "PDU size smaller than header size",
+            ));
+        };
+        let mut octets = Vec::with_capacity(total_len);
+        octets.extend_from_slice(header.as_ref());
+        octets.resize(total_len, 0);
+        sock.read_exact(&mut octets[mem::size_of::<Header>()..]).await?;
+        let res = ErrorReport { octets };
+        res.validate()?;
+        debug_assert_eq!(payload_len + mem::size_of::<Header>(), res.octets.len());
+        Ok(res)
+    }
+
+    pub fn version(&self) -> u8 {
+        self.header().version()
+    }
+
+    pub fn error_code(&self) -> Result<ErrorCode, u16> {
+        ErrorCode::try_from(self.header().error_code()).map_err(|_| self.header().error_code())
+    }
+
+    pub fn erroneous_pdu(&self) -> &[u8] {
+        &self.octets[self.erroneous_pdu_range()]
+    }
+
+    pub fn text(&self) -> &[u8] {
+        &self.octets[self.text_range()]
+    }
+
     /// Skips over the payload of the error PDU.
     pub async fn skip_payload<Sock: AsyncRead + Unpin>(
         header: Header, sock: &mut Sock
     ) -> Result<(), io::Error> {
-        let Some(mut remaining) = header.pdu_len()?.checked_sub(
-            mem::size_of::<Header>()
-        ) else {
+        let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::<Header>()) else {
            return Err(io::Error::new(
                 io::ErrorKind::InvalidData,
                 "PDU size smaller than header size",
-            ))
+            ));
         };
         let mut buf = [0u8; 1024];
         while remaining > 0 {
             let read_len = cmp::min(remaining, mem::size_of_val(&buf));
-            let read = sock.read(
-                // Safety: We limited the length to the buffer size.
-                unsafe { buf.get_unchecked_mut(..read_len) }
-            ).await?;
+            let read = sock.read(&mut buf[..read_len]).await?;
+            if read == 0 {
+                return Err(io::Error::new(
+                    io::ErrorKind::UnexpectedEof,
+                    "unexpected EOF while skipping ErrorReport payload",
+                ));
+            }
             remaining -= read;
         }
         Ok(())
@@ -705,11 +843,115 @@ impl ErrorReport {
     ) -> Result<(), io::Error> {
         a.write_all(self.as_ref()).await
     }
+
+    fn header(&self) -> Header {
+        Header::from_raw(self.header_bytes()).expect("validated ErrorReport header")
+    }
+
+    fn header_bytes(&self) -> [u8; HEADER_LEN] {
+        self.octets[..HEADER_LEN]
+            .try_into()
+            .expect("ErrorReport shorter than header")
+    }
+
+    fn erroneous_pdu_len(&self) -> usize {
+        u32::from_be_bytes(
+            self.octets[Header::LEN..Header::LEN + 4]
+                .try_into()
+                .unwrap()
+        ) as usize
+    }
+
+    fn erroneous_pdu_range(&self) -> std::ops::Range<usize> {
+        let start = Header::LEN + 4;
+        let end = start + self.erroneous_pdu_len();
+        start..end
+    }
+
+    fn text_len_offset(&self) -> usize {
+        self.erroneous_pdu_range().end
+    }
+
+    fn text_len(&self) -> usize {
+        let offset = self.text_len_offset();
+        u32::from_be_bytes(
+            self.octets[offset..offset + 4]
+                .try_into()
+                .unwrap()
+        ) as usize
+    }
+
+    fn text_range(&self) -> std::ops::Range<usize> {
+        let start = self.text_len_offset() + 4;
+        let end = start + self.text_len();
+        start..end
+    }
+
+    fn validate(&self) -> Result<(), io::Error> {
+        let header = Header::from_raw(self.header_bytes())?;
+        if header.pdu() != Self::PDU {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "unexpected PDU type for ErrorReport",
+            ));
+        }
+        let total_len = header.pdu_len()?;
+        if total_len != self.octets.len() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ErrorReport length mismatch",
+            ));
+        }
+        if self.octets.len() < Header::LEN + 8 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ErrorReport too short",
+            ));
+        }
+        let pdu_len = self.erroneous_pdu_len();
+        let text_len_offset = Header::LEN + 4 + pdu_len;
+        let Some(text_len_end) = text_len_offset.checked_add(4) else {
+            return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport length overflow"));
+        };
+        if text_len_end > self.octets.len() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ErrorReport truncated before error text length",
+            ));
+        }
+        let text_len = u32::from_be_bytes(
+            self.octets[text_len_offset..text_len_end]
+                .try_into()
+                .unwrap()
+        ) as usize;
+        let Some(text_end) = text_len_end.checked_add(text_len) else {
+            return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport text overflow"));
+        };
+        if text_end != self.octets.len() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ErrorReport payload length mismatch",
+            ));
+        }
+        if std::str::from_utf8(self.text()).is_err() {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "ErrorReport text is not valid UTF-8",
+            ));
+        }
+        Ok(())
+    }
 }
 
-// TODO: complete this
-// Router Key
+/// Router Key
 #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
 pub struct RouterKey {
     header: HeaderWithFlags,
@@ -724,17 +966,13 @@ pub struct RouterKey {
 impl RouterKey {
     pub const PDU: u8 = 9;
+    const BASE_LEN: usize = HEADER_LEN + 20 + 4;
 
     pub async fn write<A: AsyncWrite + Unpin>(
         &self,
         w: &mut A,
     ) -> Result<(), io::Error> {
-        let length = HEADER_LEN
-            + 1 // flags
-            // + self.ski.as_ref().len()
-            + 4 // ASN
-            + self.subject_public_key_info.len();
+        let length = Self::BASE_LEN + self.subject_public_key_info.len();
 
         let header = HeaderWithFlags::new(
             self.header.version(),
@@ -750,13 +988,42 @@ impl RouterKey {
             ZERO_8,
         ]).await?;
-        w.write_all(&length.to_be_bytes()).await?;
-        // w.write_all(self.ski.as_ref()).await?;
+        w.write_all(&(length as u32).to_be_bytes()).await?;
+        w.write_all(self.ski.as_ref()).await?;
         w.write_all(&self.asn.into_u32().to_be_bytes()).await?;
         w.write_all(&self.subject_public_key_info).await?;
         Ok(())
     }
+
+    pub fn new(
+        version: u8,
+        flags: Flags,
+        ski: Ski,
+        asn: Asn,
+        subject_public_key_info: Arc<[u8]>,
+    ) -> Self {
+        let length = Self::BASE_LEN + subject_public_key_info.len();
+        Self {
+            header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32),
+            flags,
+            ski,
+            asn,
+            subject_public_key_info,
+        }
+    }
+
+    pub fn ski(&self) -> Ski {
+        self.ski
+    }
+
+    pub fn asn(&self) -> Asn {
+        self.asn
+    }
+
+    pub fn spki(&self) -> &[u8] {
+        &self.subject_public_key_info
+    }
 }
@@ -772,21 +1039,19 @@ pub struct Aspa{
 impl Aspa {
     pub const PDU: u8 = 11;
+    const BASE_LEN: usize = HEADER_LEN + 4;
 
     pub async fn write<A: AsyncWrite + Unpin>(
         &self,
         w: &mut A,
     ) -> Result<(), io::Error> {
-        let length = HEADER_LEN
-            + 1
-            + 4
-            + (self.provider_asns.len() * 4);
+        let length = Self::BASE_LEN + (self.provider_asns.len() * 4);
 
         let header = HeaderWithFlags::new(
             self.header.version(),
             Self::PDU,
-            Flags::new(self.header.flags),
+            self.header.flags(),
             length as u32,
         );
@@ -797,7 +1062,7 @@ impl Aspa {
             ZERO_8,
         ]).await?;
-        w.write_all(&length.to_be_bytes()).await?;
+        w.write_all(&(length as u32).to_be_bytes()).await?;
         w.write_all(&self.customer_asn.to_be_bytes()).await?;
         for asn in &self.provider_asns {
@@ -806,6 +1071,20 @@ impl Aspa {
         Ok(())
     }
+
+    pub fn new(
+        version: u8,
+        flags: Flags,
+        customer_asn: u32,
+        provider_asns: Vec<u32>,
+    ) -> Self {
+        let length = Self::BASE_LEN + (provider_asns.len() * 4);
+        Self {
+            header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32),
+            customer_asn,
+            provider_asns,
+        }
+    }
 }
@@ -823,54 +1102,3 @@ impl AsMut<[u8]> for ErrorReport {
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use tokio::io::duplex;
-
-    #[tokio::test]
-    async fn test_serial_notify_roundtrip() {
-        let (mut client, mut server) = duplex(1024);
-        let original = SerialNotify::new(1, 42, 100);
-
-        // write
-        tokio::spawn(async move {
-            original.write(&mut client).await.unwrap();
-        });
-
-        // read
-        let decoded = SerialNotify::read(&mut server).await.unwrap();
-        assert_eq!(decoded.version(), 1);
-        assert_eq!(decoded.session_id(), 42);
-        assert_eq!(decoded.serial_number(), 100u32.to_be());
-    }
-
-    #[tokio::test]
-    async fn test_ipv4_prefix_roundtrip() {
-        use std::net::Ipv4Addr;
-        let (mut client, mut server) = duplex(1024);
-        let prefix = IPv4Prefix::new(
-            1,
-            Flags::new(1),
-            24,
-            24,
-            Ipv4Addr::new(192, 168, 0, 0),
-            65000u32.into(),
-        );
-
-        tokio::spawn(async move {
-            prefix.write(&mut client).await.unwrap();
-        });
-
-        let decoded = IPv4Prefix::read(&mut server).await.unwrap();
-        assert_eq!(decoded.prefix_len(), 24);
-        assert_eq!(decoded.max_len(), 24);
-        assert_eq!(decoded.prefix(), Ipv4Addr::new(192, 168, 0, 0));
-        assert_eq!(decoded.flag().is_announce(), true);
-    }
-}
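`ErrorReport::new` above clamps the embedded erroneous PDU and the error text so the whole report stays within `MAX_PDU_LEN`, giving the erroneous PDU priority over the text. A std-only sketch of that length budgeting (the `budget` helper is illustrative; the constants mirror the code above):

```rust
const HEADER_LEN: usize = 8;
const MAX_PDU_LEN: usize = 65_535;
const FIXED_PART_LEN: usize = HEADER_LEN + 2 * 4; // header + two u32 length fields

// Clamp the embedded PDU and text so the ErrorReport never exceeds
// MAX_PDU_LEN, preferring the erroneous PDU over the text.
// Returns (clamped pdu len, clamped text len, total PDU size).
fn budget(pdu_len: usize, text_len: usize) -> (usize, usize, usize) {
    let max_payload = MAX_PDU_LEN - FIXED_PART_LEN;
    let pdu_len = pdu_len.min(max_payload);
    let text_len = text_len.min(max_payload - pdu_len);
    (pdu_len, text_len, FIXED_PART_LEN + pdu_len + text_len)
}

fn main() {
    // Small payloads pass through untouched: 16 + 100 + 32 = 148 octets.
    assert_eq!(budget(100, 32), (100, 32, 148));

    // Oversized input is truncated so the report fits the 65535-octet limit,
    // dropping the text first.
    let (p, t, total) = budget(70_000, 1_000);
    assert_eq!(p, MAX_PDU_LEN - FIXED_PART_LEN);
    assert_eq!(t, 0);
    assert_eq!(total, MAX_PDU_LEN);
}
```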


@@ -1,7 +1,12 @@
+use std::time::Duration;
+
 #[derive(Debug, Clone)]
 pub struct RtrServiceConfig {
     pub max_connections: usize,
     pub notify_queue_size: usize,
+    pub tcp_keepalive: Option<Duration>,
+    pub warn_insecure_tcp: bool,
+    pub require_tls_server_dns_name_san: bool,
 }
 
 impl Default for RtrServiceConfig {
@@ -9,6 +14,9 @@ impl Default for RtrServiceConfig {
         Self {
             max_connections: 1024,
             notify_queue_size: 1024,
+            tcp_keepalive: Some(Duration::from_secs(60)),
+            warn_insecure_tcp: true,
+            require_tls_server_dns_name_san: false,
         }
     }
 }


@@ -1,13 +1,15 @@
-use std::net::SocketAddr;
+use std::net::{IpAddr, SocketAddr};
 use std::sync::{
     Arc,
     atomic::{AtomicUsize, Ordering},
 };
 
-use anyhow::{Context, Result};
+use anyhow::{Context, Result, anyhow};
 use tokio::net::TcpStream;
 use tokio::sync::{broadcast, watch, OwnedSemaphorePermit};
-use tracing::error;
+use tracing::{error, info, warn};
+use x509_parser::extensions::GeneralName;
+use x509_parser::prelude::{FromDer, X509Certificate};
 
 use tokio_rustls::TlsAcceptor;
@@ -30,6 +32,10 @@ impl ConnectionGuard {
             _permit: permit,
         }
     }
+
+    pub fn active_count(&self) -> usize {
+        self.active_connections.load(Ordering::Relaxed)
+    }
 }
 
 impl Drop for ConnectionGuard {
@@ -52,6 +58,7 @@ pub async fn handle_tcp_connection(
         return Err(err);
     }
 
+    info!("RTR TCP session completed normally for {}", peer_addr);
     Ok(())
 }
@@ -63,10 +70,15 @@ pub async fn handle_tls_connection(
     notify_rx: broadcast::Receiver<()>,
     shutdown_rx: watch::Receiver<bool>,
 ) -> Result<()> {
+    info!("RTR TLS handshake started for {}", peer_addr);
     let tls_stream = acceptor
         .accept(stream)
         .await
         .with_context(|| format!("TLS handshake failed for {}", peer_addr))?;
+    info!("RTR TLS handshake completed for {}", peer_addr);
+    verify_peer_certificate_ip(&tls_stream, peer_addr.ip())
+        .with_context(|| format!("TLS client certificate SAN IP validation failed for {}", peer_addr))?;
+    info!("RTR TLS client certificate validated for {}", peer_addr);
 
     let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx);
@@ -75,5 +87,57 @@ pub async fn handle_tls_connection(
         return Err(err);
     }
 
+    info!("RTR TLS session completed normally for {}", peer_addr);
     Ok(())
 }
+
+fn verify_peer_certificate_ip(
+    tls_stream: &tokio_rustls::server::TlsStream<TcpStream>,
+    peer_ip: IpAddr,
+) -> Result<()> {
+    let (_, server_connection) = tls_stream.get_ref();
+    let peer_certs = server_connection
+        .peer_certificates()
+        .ok_or_else(|| anyhow!("missing peer certificate after TLS client authentication"))?;
+    let end_entity = peer_certs
+        .first()
+        .ok_or_else(|| anyhow!("peer did not present an end-entity certificate"))?;
+    let (_, cert) = X509Certificate::from_der(end_entity.as_ref())
+        .map_err(|err| anyhow!("failed to parse peer certificate: {:?}", err))?;
+    let san = cert
+        .subject_alternative_name()
+        .map_err(|err| anyhow!("failed to parse peer certificate SAN: {:?}", err))?
+        .ok_or_else(|| anyhow!("peer certificate is missing subjectAltName"))?;
+    let matched = san.value.general_names.iter().any(|name| match name {
+        GeneralName::IPAddress(bytes) => {
+            let bytes = *bytes;
+            match (peer_ip, bytes.len()) {
+                (IpAddr::V4(ip), 4) => <[u8; 4]>::try_from(bytes)
+                    .map(IpAddr::from)
+                    .map(|cert_ip| cert_ip == IpAddr::V4(ip))
+                    .unwrap_or(false),
+                (IpAddr::V6(ip), 16) => <[u8; 16]>::try_from(bytes)
+                    .map(IpAddr::from)
+                    .map(|cert_ip| cert_ip == IpAddr::V6(ip))
+                    .unwrap_or(false),
+                _ => false,
+            }
+        }
+        _ => false,
+    });
+    if matched {
+        Ok(())
+    } else {
+        warn!(
+            "RTR TLS client certificate SAN IP mismatch for peer_ip={}",
+            peer_ip
+        );
+        Err(anyhow!(
+            "peer certificate subjectAltName iPAddress does not match {}",
+            peer_ip
+        ))
+    }
+}
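`verify_peer_certificate_ip` above compares the raw `iPAddress` bytes from the client certificate's subjectAltName against the peer's socket address: a 4-octet entry may only match a v4 peer, a 16-octet entry only a v6 peer. The byte-level comparison can be exercised on its own with std types (the `san_ip_matches` helper is a hypothetical stand-in for the match arm above):

```rust
use std::net::IpAddr;

// Match a raw subjectAltName iPAddress entry (4 or 16 octets, network
// order) against the connection's peer address.
fn san_ip_matches(san_bytes: &[u8], peer_ip: IpAddr) -> bool {
    match (peer_ip, san_bytes.len()) {
        (IpAddr::V4(ip), 4) => <[u8; 4]>::try_from(san_bytes)
            .map(|b| IpAddr::from(b) == IpAddr::V4(ip))
            .unwrap_or(false),
        (IpAddr::V6(ip), 16) => <[u8; 16]>::try_from(san_bytes)
            .map(|b| IpAddr::from(b) == IpAddr::V6(ip))
            .unwrap_or(false),
        _ => false, // address-family/length mismatch never matches
    }
}

fn main() {
    let peer: IpAddr = "192.0.2.10".parse().unwrap();
    assert!(san_ip_matches(&[192, 0, 2, 10], peer));
    assert!(!san_ip_matches(&[192, 0, 2, 11], peer));
    assert!(!san_ip_matches(&[0u8; 16], peer)); // v6 entry vs v4 peer
}
```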


@ -4,8 +4,10 @@ use std::sync::{
Arc, Arc,
atomic::AtomicUsize, atomic::AtomicUsize,
}; };
use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use socket2::{SockRef, TcpKeepalive};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::{broadcast, watch, Semaphore}; use tokio::sync::{broadcast, watch, Semaphore};
use tracing::{info, warn}; use tracing::{info, warn};
@ -15,7 +17,8 @@ use tokio_rustls::TlsAcceptor;
use crate::rtr::cache::SharedRtrCache; use crate::rtr::cache::SharedRtrCache;
use crate::rtr::server::connection::{ConnectionGuard, handle_tcp_connection, handle_tls_connection}; use crate::rtr::server::connection::{ConnectionGuard, handle_tcp_connection, handle_tls_connection};
use crate::rtr::server::tls::load_rustls_server_config; use crate::rtr::server::config::RtrServiceConfig;
use crate::rtr::server::tls::load_rustls_server_config_with_options;
pub struct RtrServer { pub struct RtrServer {
bind_addr: SocketAddr, bind_addr: SocketAddr,
@ -24,6 +27,7 @@ pub struct RtrServer {
shutdown_tx: watch::Sender<bool>, shutdown_tx: watch::Sender<bool>,
connection_limiter: Arc<Semaphore>, connection_limiter: Arc<Semaphore>,
active_connections: Arc<AtomicUsize>, active_connections: Arc<AtomicUsize>,
config: RtrServiceConfig,
} }
impl RtrServer { impl RtrServer {
@ -34,6 +38,7 @@ impl RtrServer {
shutdown_tx: watch::Sender<bool>, shutdown_tx: watch::Sender<bool>,
connection_limiter: Arc<Semaphore>, connection_limiter: Arc<Semaphore>,
active_connections: Arc<AtomicUsize>, active_connections: Arc<AtomicUsize>,
config: RtrServiceConfig,
) -> Self { ) -> Self {
Self { Self {
bind_addr, bind_addr,
@ -42,6 +47,7 @@ impl RtrServer {
shutdown_tx, shutdown_tx,
connection_limiter, connection_limiter,
active_connections, active_connections,
config,
} }
} }
@ -95,6 +101,9 @@ impl RtrServer {
if let Err(err) = stream.set_nodelay(true) { if let Err(err) = stream.set_nodelay(true) {
warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err); warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err);
} }
if let Err(err) = apply_keepalive(&stream, self.config.tcp_keepalive) {
warn!("failed to configure TCP keepalive for {}: {}", peer_addr, err);
}
let permit = match self.connection_limiter.clone().try_acquire_owned() { let permit = match self.connection_limiter.clone().try_acquire_owned() {
Ok(permit) => permit, Ok(permit) => permit,
@ -102,7 +111,7 @@ impl RtrServer {
warn!( warn!(
"RTR TCP connection rejected for {}: max connections reached ({})", "RTR TCP connection rejected for {}: max connections reached ({})",
peer_addr, peer_addr,
self.connection_limiter.available_permits() self.config.max_connections
); );
drop(stream); drop(stream);
continue; continue;
@ -114,16 +123,34 @@ impl RtrServer {
let shutdown_rx = self.shutdown_tx.subscribe(); let shutdown_rx = self.shutdown_tx.subscribe();
let active_connections = self.active_connections.clone(); let active_connections = self.active_connections.clone();
info!("RTR TCP client connected: {}", peer_addr); info!(
"RTR TCP client connected: peer_addr={}, active_connections(before_spawn)={}",
peer_addr,
self.active_connections()
);
tokio::spawn(async move { tokio::spawn(async move {
let _guard = ConnectionGuard::new(active_connections, permit); let guard = ConnectionGuard::new(active_connections, permit);
info!(
"RTR TCP connection established: peer_addr={}, active_connections={}",
peer_addr,
guard.active_count()
);
if let Err(err) = if let Err(err) =
handle_tcp_connection(cache, stream, peer_addr, notify_rx, shutdown_rx).await handle_tcp_connection(cache, stream, peer_addr, notify_rx, shutdown_rx).await
{ {
warn!("RTR TCP session {} ended with error: {:?}", peer_addr, err); warn!(
"RTR TCP session closed with error: peer_addr={}, active_connections={}, err={:?}",
peer_addr,
guard.active_count(),
err
);
} else { } else {
info!("RTR TCP session {} closed", peer_addr); info!(
"RTR TCP session closed cleanly: peer_addr={}, active_connections={}",
peer_addr,
guard.active_count()
);
} }
}); });
} }
@ -135,8 +162,14 @@ impl RtrServer {
self, self,
cert_path: impl AsRef<Path>, cert_path: impl AsRef<Path>,
key_path: impl AsRef<Path>, key_path: impl AsRef<Path>,
client_ca_path: impl AsRef<Path>,
) -> Result<()> { ) -> Result<()> {
let tls_config = Arc::new(load_rustls_server_config(cert_path, key_path)?); let tls_config = Arc::new(load_rustls_server_config_with_options(
cert_path,
key_path,
client_ca_path,
self.config.require_tls_server_dns_name_san,
)?);
        self.run_tls(tls_config).await
    }

@@ -179,11 +212,18 @@ impl RtrServer
            if let Err(err) = stream.set_nodelay(true) {
                warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err);
            }
            if let Err(err) = apply_keepalive(&stream, self.config.tcp_keepalive) {
                warn!("failed to configure TCP keepalive for {}: {}", peer_addr, err);
            }
            let permit = match self.connection_limiter.clone().try_acquire_owned() {
                Ok(permit) => permit,
                Err(_) => {
                    warn!(
                        "RTR TLS connection rejected for {}: max connections reached ({})",
                        peer_addr,
                        self.config.max_connections
                    );
                    drop(stream);
                    continue;
                }
@@ -195,10 +235,19 @@ impl RtrServer
            let shutdown_rx = self.shutdown_tx.subscribe();
            let active_connections = self.active_connections.clone();
            info!(
                "RTR TLS client connected: peer_addr={}, active_connections(before_spawn)={}",
                peer_addr,
                self.active_connections()
            );
            tokio::spawn(async move {
                let guard = ConnectionGuard::new(active_connections, permit);
                info!(
                    "RTR TLS connection established: peer_addr={}, active_connections={}",
                    peer_addr,
                    guard.active_count()
                );
                if let Err(err) = handle_tls_connection(
                    cache,
                    stream,
@@ -207,13 +256,38 @@ impl RtrServer
                    notify_rx,
                    shutdown_rx,
                ).await {
                    warn!(
                        "RTR TLS session closed with error: peer_addr={}, active_connections={}, err={:?}",
                        peer_addr,
                        guard.active_count(),
                        err
                    );
                } else {
                    info!(
                        "RTR TLS session closed cleanly: peer_addr={}, active_connections={}",
                        peer_addr,
                        guard.active_count()
                    );
                }
            });
                }
            }
        }
    }
}
fn apply_keepalive(
    stream: &tokio::net::TcpStream,
    keepalive: Option<Duration>,
) -> Result<()> {
    let Some(keepalive) = keepalive else {
        return Ok(());
    };
    let socket = SockRef::from(stream);
    let keepalive = TcpKeepalive::new().with_time(keepalive);
    socket
        .set_tcp_keepalive(&keepalive)
        .context("unable to apply TCP keepalive settings")?;
    Ok(())
}


@@ -7,7 +7,7 @@ use std::sync::{
use tokio::sync::{broadcast, watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, warn};

use crate::rtr::cache::SharedRtrCache;
use crate::rtr::server::config::RtrServiceConfig;
@@ -70,6 +70,7 @@ impl RtrService
            self.shutdown_tx.clone(),
            self.connection_limiter.clone(),
            self.active_connections.clone(),
            self.config.clone(),
        )
    }
@@ -81,10 +82,17 @@ impl RtrService
            self.shutdown_tx.clone(),
            self.connection_limiter.clone(),
            self.active_connections.clone(),
            self.config.clone(),
        )
    }

    pub fn spawn_tcp(&self, bind_addr: SocketAddr) -> JoinHandle<()> {
        if self.config.warn_insecure_tcp {
            warn!(
                "starting plain TCP RTR service on {}. Per draft-ietf-sidrops-8210bis-25 Section 9, unsecured TCP must only be used on a trusted and controlled network",
                bind_addr
            );
        }
        let server = self.tcp_server(bind_addr);
        tokio::spawn(async move {
            if let Err(err) = server.run_tcp().await {
@@ -98,13 +106,15 @@ impl RtrService
        bind_addr: SocketAddr,
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        client_ca_path: impl AsRef<Path>,
    ) -> JoinHandle<()> {
        let cert_path = cert_path.as_ref().to_path_buf();
        let key_path = key_path.as_ref().to_path_buf();
        let client_ca_path = client_ca_path.as_ref().to_path_buf();
        let server = self.tls_server(bind_addr);
        tokio::spawn(async move {
            if let Err(err) = server.run_tls_from_pem(cert_path, key_path, client_ca_path).await {
                error!("RTR TLS server {} exited with error: {:?}", bind_addr, err);
            }
        })
@@ -116,9 +126,10 @@ impl RtrService
        tls_bind_addr: SocketAddr,
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        client_ca_path: impl AsRef<Path>,
    ) -> RunningRtrService {
        let tcp_handle = self.spawn_tcp(tcp_bind_addr);
        let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path);

        RunningRtrService {
            shutdown_tx: self.shutdown_tx.clone(),
@@ -151,4 +162,4 @@ impl RunningRtrService
            let _ = handle.await;
        }
    }
}


@@ -1,32 +1,101 @@
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::{anyhow, Context, Result};
use rustls::server::WebPkiClientVerifier;
use rustls::{RootCertStore, ServerConfig};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use tracing::warn;
use x509_parser::extensions::GeneralName;
use x509_parser::prelude::{FromDer, X509Certificate};

pub fn load_rustls_server_config(
    cert_path: impl AsRef<Path>,
    key_path: impl AsRef<Path>,
    client_ca_path: impl AsRef<Path>,
) -> Result<ServerConfig> {
    load_rustls_server_config_with_options(cert_path, key_path, client_ca_path, false)
}

pub fn load_rustls_server_config_with_options(
    cert_path: impl AsRef<Path>,
    key_path: impl AsRef<Path>,
    client_ca_path: impl AsRef<Path>,
    require_dns_name_san: bool,
) -> Result<ServerConfig> {
    let cert_path: PathBuf = cert_path.as_ref().to_path_buf();
    let key_path: PathBuf = key_path.as_ref().to_path_buf();
    let client_ca_path: PathBuf = client_ca_path.as_ref().to_path_buf();

    let certs = load_certs(&cert_path)
        .with_context(|| format!("failed to load certs from {}", cert_path.display()))?;
    validate_server_certificate_dns_name_san(&certs, &cert_path, require_dns_name_san)?;
    let key = load_private_key(&key_path)
        .with_context(|| format!("failed to load private key from {}", key_path.display()))?;
    let client_ca_certs = load_certs(&client_ca_path)
        .with_context(|| format!("failed to load client CA certs from {}", client_ca_path.display()))?;

    let mut client_roots = RootCertStore::empty();
    let (added, _) = client_roots.add_parsable_certificates(client_ca_certs);
    if added == 0 {
        return Err(anyhow!(
            "no valid client CA certificates found in {}",
            client_ca_path.display()
        ));
    }
    let client_verifier = WebPkiClientVerifier::builder(Arc::new(client_roots))
        .build()
        .map_err(|e| anyhow!("invalid client certificate verifier configuration: {}", e))?;

    let config = ServerConfig::builder()
        .with_client_cert_verifier(client_verifier)
        .with_single_cert(certs, key)
        .map_err(|e| anyhow!("invalid certificate/key pair: {}", e))?;

    Ok(config)
}

fn validate_server_certificate_dns_name_san(
    certs: &[CertificateDer<'static>],
    cert_path: &Path,
    require_dns_name_san: bool,
) -> Result<()> {
    let leaf = certs
        .first()
        .ok_or_else(|| anyhow!("missing end-entity certificate in {}", cert_path.display()))?;
    let (_, cert) = X509Certificate::from_der(leaf.as_ref())
        .map_err(|err| anyhow!("failed to parse server certificate: {:?}", err))?;
    let has_dns_name_san = cert
        .subject_alternative_name()
        .map_err(|err| anyhow!("failed to parse server certificate SAN: {:?}", err))?
        .map(|san| {
            san.value
                .general_names
                .iter()
                .any(|name| matches!(name, GeneralName::DNSName(_)))
        })
        .unwrap_or(false);
    if has_dns_name_san {
        return Ok(());
    }
    let message = format!(
        "server certificate {} does not contain a subjectAltName dNSName entry; draft-ietf-sidrops-8210bis-25 Section 9.2 requires routers to authenticate the cache using DNS-ID rather than CN-ID",
        cert_path.display()
    );
    if require_dns_name_san {
        Err(anyhow!(message))
    } else {
        warn!("{}", message);
        Ok(())
    }
}

fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
@@ -49,4 +118,4 @@ fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
        .ok_or_else(|| anyhow!("no private key found in {}", path.display()))?;
    Ok(key)
}

(File diff suppressed because it is too large.)

@@ -1,14 +1,16 @@
use crate::rtr::cache::SessionIds;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    session_ids: SessionIds,
    serial: u32,
}

impl State {
    pub fn session_ids(self) -> SessionIds {
        self.session_ids
    }

    pub fn serial(self) -> u32 {
src/rtr/store.rs (new file, 619 lines)

@@ -0,0 +1,619 @@
use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch};
use anyhow::{anyhow, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path;
use std::sync::Arc;
use tokio::task;
use tracing::{debug, info, warn};
use crate::rtr::cache::{CacheAvailability, Delta, SessionIds, Snapshot};
use crate::rtr::state::State;
const CF_META: &str = "meta";
const CF_SNAPSHOT: &str = "snapshot";
const CF_DELTA: &str = "delta";
const META_STATE: &[u8] = b"state";
const META_SESSION_IDS: &[u8] = b"session_ids";
const META_SERIAL: &[u8] = b"serial";
const META_AVAILABILITY: &[u8] = b"availability";
const META_DELTA_MIN: &[u8] = b"delta_min";
const META_DELTA_MAX: &[u8] = b"delta_max";
const DELTA_KEY_PREFIX: u8 = b'd';
fn delta_key(serial: u32) -> [u8; 5] {
let mut key = [0u8; 5];
key[0] = DELTA_KEY_PREFIX;
key[1..].copy_from_slice(&serial.to_be_bytes());
key
}
fn delta_key_serial(key: &[u8]) -> Option<u32> {
if key.len() != 5 || key[0] != DELTA_KEY_PREFIX {
return None;
}
let mut bytes = [0u8; 4];
bytes.copy_from_slice(&key[1..]);
Some(u32::from_be_bytes(bytes))
}
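The `'d'` prefix plus a 4-byte big-endian serial keeps every delta record under one key prefix while making RocksDB's lexicographic key order agree with numeric serial order (until wraparound, which the window logic further down handles separately). A standalone, std-only sketch of the round-trip, mirroring the two helpers above:

```rust
// Mirrors delta_key / delta_key_serial above: 'd' prefix + 4-byte big-endian serial.
const DELTA_KEY_PREFIX: u8 = b'd';

fn delta_key(serial: u32) -> [u8; 5] {
    let mut key = [0u8; 5];
    key[0] = DELTA_KEY_PREFIX;
    key[1..].copy_from_slice(&serial.to_be_bytes());
    key
}

fn delta_key_serial(key: &[u8]) -> Option<u32> {
    if key.len() != 5 || key[0] != DELTA_KEY_PREFIX {
        return None;
    }
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(&key[1..]);
    Some(u32::from_be_bytes(bytes))
}

fn main() {
    // Round-trip: encoding then decoding recovers the serial.
    assert_eq!(delta_key_serial(&delta_key(42)), Some(42));
    // Keys with the wrong length or prefix are rejected.
    assert_eq!(delta_key_serial(b"meta"), None);
    // Big-endian bytes preserve numeric order under byte-wise comparison,
    // which is what the forward iterator over CF_DELTA relies on.
    assert!(delta_key(1) < delta_key(2));
    assert!(delta_key(255) < delta_key(256));
}
```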
#[derive(Clone)]
pub struct RtrStore {
db: Arc<DB>,
}
impl RtrStore {
/// Open or create DB with required column families.
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path_ref = path.as_ref();
let mut opts = Options::default();
opts.create_if_missing(true);
opts.create_missing_column_families(true);
let cfs = vec![
ColumnFamilyDescriptor::new(CF_META, Options::default()),
ColumnFamilyDescriptor::new(CF_SNAPSHOT, Options::default()),
ColumnFamilyDescriptor::new(CF_DELTA, Options::default()),
];
info!("opening RTR RocksDB store at {}", path_ref.display());
let db = Arc::new(DB::open_cf_descriptors(&opts, path_ref, cfs)?);
info!("opened RTR RocksDB store at {}", path_ref.display());
Ok(Self { db })
}
/// Common serialize/put.
fn put_cf<T: Serialize>(&self, cf: &str, key: &[u8], value: &T) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
let data = serde_json::to_vec(value)?;
self.db.put_cf(cf_handle, key, data)?;
Ok(())
}
/// Common get/deserialize.
fn get_cf<T: DeserializeOwned>(&self, cf: &str, key: &[u8]) -> Result<Option<T>> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
if let Some(value) = self.db.get_cf(cf_handle, key)? {
let obj = serde_json::from_slice(&value)?;
Ok(Some(obj))
} else {
Ok(None)
}
}
/// Common delete.
fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
self.db.delete_cf(cf_handle, key)?;
Ok(())
}
// ===============================
// Meta/state
// ===============================
pub fn set_state(&self, state: &State) -> Result<()> {
self.put_cf(CF_META, META_STATE, &state)
}
pub fn get_state(&self) -> Result<Option<State>> {
self.get_cf(CF_META, META_STATE)
}
pub fn set_meta(&self, meta: &State) -> Result<()> {
self.set_state(meta)
}
pub fn get_meta(&self) -> Result<Option<State>> {
self.get_state()
}
pub fn set_session_ids(&self, session_ids: &SessionIds) -> Result<()> {
self.put_cf(CF_META, META_SESSION_IDS, session_ids)
}
pub fn get_session_ids(&self) -> Result<Option<SessionIds>> {
self.get_cf(CF_META, META_SESSION_IDS)
}
pub fn set_serial(&self, serial: u32) -> Result<()> {
self.put_cf(CF_META, META_SERIAL, &serial)
}
pub fn get_serial(&self) -> Result<Option<u32>> {
self.get_cf(CF_META, META_SERIAL)
}
pub fn set_availability(&self, availability: CacheAvailability) -> Result<()> {
self.put_cf(CF_META, META_AVAILABILITY, &availability)
}
pub fn get_availability(&self) -> Result<Option<CacheAvailability>> {
self.get_cf(CF_META, META_AVAILABILITY)
}
pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> {
debug!(
"RTR store persisting delta window metadata: min_serial={}, max_serial={}",
min_serial,
max_serial
);
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default();
batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?);
batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?);
self.db.write(batch)?;
Ok(())
}
pub fn clear_delta_window(&self) -> Result<()> {
debug!("RTR store clearing delta window metadata");
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default();
batch.delete_cf(meta_cf, META_DELTA_MIN);
batch.delete_cf(meta_cf, META_DELTA_MAX);
self.db.write(batch)?;
Ok(())
}
pub fn get_delta_window(&self) -> Result<Option<(u32, u32)>> {
let min: Option<u32> = self.get_cf(CF_META, META_DELTA_MIN)?;
let max: Option<u32> = self.get_cf(CF_META, META_DELTA_MAX)?;
match (min, max) {
(Some(min), Some(max)) => {
debug!(
"RTR store loaded delta window metadata: min_serial={}, max_serial={}",
min,
max
);
Ok(Some((min, max)))
}
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: delta window mismatch")),
}
}
pub fn delete_state(&self) -> Result<()> {
self.delete_cf(CF_META, META_STATE)
}
pub fn delete_serial(&self) -> Result<()> {
self.delete_cf(CF_META, META_SERIAL)
}
// ===============================
// Snapshot
// ===============================
pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let mut batch = WriteBatch::default();
let data = serde_json::to_vec(snapshot)?;
batch.put_cf(cf_handle, b"current", data);
self.db.write(batch)?;
Ok(())
}
pub fn get_snapshot(&self) -> Result<Option<Snapshot>> {
self.get_cf(CF_SNAPSHOT, b"current")
}
pub fn delete_snapshot(&self) -> Result<()> {
self.delete_cf(CF_SNAPSHOT, b"current")
}
pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> {
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default();
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_STATE, serde_json::to_vec(state)?);
batch.put_cf(
meta_cf,
META_SESSION_IDS,
serde_json::to_vec(&state.clone().session_ids())?,
);
batch.put_cf(
meta_cf,
META_SERIAL,
serde_json::to_vec(&state.clone().serial())?,
);
self.db.write(batch)?;
Ok(())
}
pub fn save_snapshot_and_meta(
&self,
snapshot: &Snapshot,
session_ids: &SessionIds,
serial: u32,
) -> Result<()> {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
self.db.write(batch)?;
Ok(())
}
pub fn save_cache_state(
&self,
availability: CacheAvailability,
snapshot: &Snapshot,
session_ids: &SessionIds,
serial: u32,
delta: Option<&Delta>,
delta_window: Option<(u32, u32)>,
clear_delta_window: bool,
) -> Result<()> {
debug!(
"RTR store save_cache_state start: availability={:?}, serial={}, session_ids={:?}, delta_present={}, delta_window={:?}, clear_delta_window={}, snapshot(route_origins={}, router_keys={}, aspas={})",
availability,
serial,
session_ids,
delta.is_some(),
delta_window,
clear_delta_window,
snapshot.origins().len(),
snapshot.router_keys().len(),
snapshot.aspas().len()
);
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let delta_cf = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let mut batch = WriteBatch::default();
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
batch.put_cf(meta_cf, META_AVAILABILITY, serde_json::to_vec(&availability)?);
if let Some(delta) = delta {
debug!(
"RTR store persisting delta: serial={}, announced={}, withdrawn={}",
delta.serial(),
delta.announced().len(),
delta.withdrawn().len()
);
batch.put_cf(delta_cf, delta_key(delta.serial()), serde_json::to_vec(delta)?);
}
if clear_delta_window {
let existing_keys = self.list_delta_keys()?;
let existing_serials = summarize_delta_keys(&existing_keys);
info!(
"RTR store clearing persisted delta window: deleting {} delta records, serials={}",
existing_keys.len(),
existing_serials
);
batch.delete_cf(meta_cf, META_DELTA_MIN);
batch.delete_cf(meta_cf, META_DELTA_MAX);
for key in existing_keys {
batch.delete_cf(delta_cf, key);
}
} else if let Some((min_serial, max_serial)) = delta_window {
batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?);
batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?);
// Serial numbers are compared in RFC 1982 ring order, while RocksDB stores
// keys in plain lexicographic order. After wraparound, a window such as
// [u32::MAX, 1] is contiguous in serial space but split in key space, so a
// simple delete_range() would leave stale high-serial keys behind.
let stale_keys = self.list_delta_keys_outside_window(min_serial, max_serial)?;
if !stale_keys.is_empty() {
info!(
"RTR store pruning stale delta records outside window [{}, {}]: count={}, serials={}",
min_serial,
max_serial,
stale_keys.len(),
summarize_delta_keys(&stale_keys)
);
} else {
debug!(
"RTR store found no stale delta records outside window [{}, {}]",
min_serial,
max_serial
);
}
for key in stale_keys {
batch.delete_cf(delta_cf, key);
}
}
self.db.write(batch)?;
debug!("RTR store save_cache_state completed: serial={}", serial);
Ok(())
}
pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
self.db.write(batch)?;
Ok(())
}
pub async fn save_snapshot_and_serial_async(
self: Arc<Self>,
snapshot: Snapshot,
serial: u32,
) -> Result<()> {
let snapshot_bytes = serde_json::to_vec(&snapshot)?;
let serial_bytes = serde_json::to_vec(&serial)?;
task::spawn_blocking(move || {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", snapshot_bytes);
batch.put_cf(meta_cf, META_SERIAL, serial_bytes);
self.db.write(batch)?;
Ok::<_, anyhow::Error>(())
})
.await??;
Ok(())
}
pub fn load_snapshot_and_state(&self) -> Result<Option<(Snapshot, State)>> {
let snapshot: Option<Snapshot> = self.get_snapshot()?;
let state: Option<State> = self.get_state()?;
match (snapshot, state) {
(Some(snap), Some(state)) => Ok(Some((snap, state))),
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")),
}
}
pub fn load_snapshot_and_serial(&self) -> Result<Option<(Snapshot, u32)>> {
let snapshot: Option<Snapshot> = self.get_snapshot()?;
let serial: Option<u32> = self.get_serial()?;
match (snapshot, serial) {
(Some(snap), Some(serial)) => Ok(Some((snap, serial))),
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")),
}
}
// ===============================
// Delta
// ===============================
pub fn save_delta(&self, delta: &Delta) -> Result<()> {
self.put_cf(CF_DELTA, &delta_key(delta.serial()), delta)
}
pub fn get_delta(&self, serial: u32) -> Result<Option<Delta>> {
self.get_cf(CF_DELTA, &delta_key(serial))
}
pub fn load_deltas_since(&self, serial: u32) -> Result<Vec<Delta>> {
let cf_handle = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let start_key = delta_key(serial.wrapping_add(1));
let iter = self.db.iterator_cf(
cf_handle,
IteratorMode::From(&start_key, Direction::Forward),
);
let mut out = Vec::new();
for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref())
.ok_or_else(|| anyhow!("Invalid delta key"))?;
if parsed <= serial {
continue;
}
let delta: Delta = serde_json::from_slice(value.as_ref())?;
out.push(delta);
}
Ok(out)
}
pub fn load_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<Vec<Delta>> {
info!(
"RTR store loading persisted delta window: min_serial={}, max_serial={}",
min_serial,
max_serial
);
let cf_handle = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start);
let mut out = Vec::new();
for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref())
.ok_or_else(|| anyhow!("Invalid delta key"))?;
// Restore by the persisted window bounds instead of load_deltas_since().
// The latter follows lexicographic key order and is not safe across serial
// wraparound, where older high-valued keys may otherwise be pulled back in.
if serial_in_window(parsed, min_serial, max_serial) {
let delta: Delta = serde_json::from_slice(value.as_ref())?;
out.push(delta);
}
}
out.sort_by_key(|delta| delta.serial().wrapping_sub(min_serial));
debug!(
"RTR store loaded delta candidates for window [{}, {}]: count={}, serials={}",
min_serial,
max_serial,
out.len(),
summarize_delta_serials(&out)
);
validate_delta_window(&out, min_serial, max_serial)?;
info!(
"RTR store restored valid delta window: min_serial={}, max_serial={}, count={}, serials={}",
min_serial,
max_serial,
out.len(),
summarize_delta_serials(&out)
);
Ok(out)
}
pub fn delete_delta(&self, serial: u32) -> Result<()> {
self.delete_cf(CF_DELTA, &delta_key(serial))
}
fn list_delta_keys(&self) -> Result<Vec<Vec<u8>>> {
let cf_handle = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start);
let mut keys = Vec::new();
for item in iter {
let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
keys.push(key.to_vec());
}
Ok(keys)
}
fn list_delta_keys_outside_window(&self, min_serial: u32, max_serial: u32) -> Result<Vec<Vec<u8>>> {
let cf_handle = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let iter = self.db.iterator_cf(cf_handle, IteratorMode::Start);
let mut keys = Vec::new();
for item in iter {
let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let serial = delta_key_serial(key.as_ref())
.ok_or_else(|| anyhow!("Invalid delta key"))?;
if !serial_in_window(serial, min_serial, max_serial) {
keys.push(key.to_vec());
}
}
Ok(keys)
}
}
fn serial_in_window(serial: u32, min_serial: u32, max_serial: u32) -> bool {
serial.wrapping_sub(min_serial) <= max_serial.wrapping_sub(min_serial)
}
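`serial_in_window` is an RFC 1982-style membership test done with a single wrapping subtraction: a serial lies in `[min, max]` iff its wrapped distance from `min` does not exceed the window's wrapped length. A small standalone check of the behavior around the `u32` wraparound:

```rust
// Same check as serial_in_window above.
fn serial_in_window(serial: u32, min_serial: u32, max_serial: u32) -> bool {
    serial.wrapping_sub(min_serial) <= max_serial.wrapping_sub(min_serial)
}

fn main() {
    // Ordinary, non-wrapping window.
    assert!(serial_in_window(5, 3, 9));
    assert!(!serial_in_window(10, 3, 9));
    // Window [u32::MAX - 1, 1] is contiguous in serial space but split in
    // lexicographic key space, which is why pruning cannot use delete_range().
    assert!(serial_in_window(u32::MAX, u32::MAX - 1, 1));
    assert!(serial_in_window(0, u32::MAX - 1, 1));
    assert!(!serial_in_window(2, u32::MAX - 1, 1));
    assert!(!serial_in_window(100, u32::MAX - 1, 1));
}
```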
fn validate_delta_window(deltas: &[Delta], min_serial: u32, max_serial: u32) -> Result<()> {
if deltas.is_empty() {
warn!(
"RTR store delta window validation failed: no persisted deltas for window [{}, {}]",
min_serial,
max_serial
);
return Err(anyhow!(
"delta window [{}, {}] has no persisted deltas",
min_serial,
max_serial
));
}
if deltas.first().map(Delta::serial) != Some(min_serial) {
warn!(
"RTR store delta window validation failed: first delta mismatch for window [{}, {}], got {:?}",
min_serial,
max_serial,
deltas.first().map(Delta::serial)
);
return Err(anyhow!(
"delta window starts at {}, but first persisted delta is {:?}",
min_serial,
deltas.first().map(Delta::serial)
));
}
if deltas.last().map(Delta::serial) != Some(max_serial) {
warn!(
"RTR store delta window validation failed: last delta mismatch for window [{}, {}], got {:?}",
min_serial,
max_serial,
deltas.last().map(Delta::serial)
);
return Err(anyhow!(
"delta window ends at {}, but last persisted delta is {:?}",
max_serial,
deltas.last().map(Delta::serial)
));
}
for pair in deltas.windows(2) {
if pair[1].serial() != pair[0].serial().wrapping_add(1) {
warn!(
"RTR store delta window validation failed: non-contiguous pair {} -> {} within window [{}, {}]",
pair[0].serial(),
pair[1].serial(),
min_serial,
max_serial
);
return Err(anyhow!(
"persisted deltas are not contiguous: {} -> {}",
pair[0].serial(),
pair[1].serial()
));
}
}
Ok(())
}
fn summarize_delta_keys(keys: &[Vec<u8>]) -> String {
let serials: Vec<u32> = keys
.iter()
.filter_map(|key| delta_key_serial(key))
.collect();
summarize_serials(&serials)
}
fn summarize_delta_serials(deltas: &[Delta]) -> String {
let serials: Vec<u32> = deltas.iter().map(Delta::serial).collect();
summarize_serials(&serials)
}
fn summarize_serials(serials: &[u32]) -> String {
const MAX_INLINE: usize = 12;
if serials.is_empty() {
return "[]".to_string();
}
if serials.len() <= MAX_INLINE {
return format!("{:?}", serials);
}
let head: Vec<u32> = serials.iter().take(6).copied().collect();
let tail: Vec<u32> = serials
.iter()
.rev()
.take(3)
.copied()
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
format!("{:?} ... {:?} (total={})", head, tail, serials.len())
}
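`summarize_serials` keeps log lines bounded by printing only the first six and last three serials once the list exceeds `MAX_INLINE` (12). A standalone copy of the helper showing the truncated form:

```rust
// Same truncation rule as summarize_serials above (MAX_INLINE = 12).
fn summarize_serials(serials: &[u32]) -> String {
    const MAX_INLINE: usize = 12;
    if serials.is_empty() {
        return "[]".to_string();
    }
    if serials.len() <= MAX_INLINE {
        return format!("{:?}", serials);
    }
    let head: Vec<u32> = serials.iter().take(6).copied().collect();
    let tail: Vec<u32> = serials
        .iter()
        .rev()
        .take(3)
        .copied()
        .collect::<Vec<_>>()
        .into_iter()
        .rev()
        .collect();
    format!("{:?} ... {:?} (total={})", head, tail, serials.len())
}

fn main() {
    assert_eq!(summarize_serials(&[]), "[]");
    // Short lists are printed in full.
    assert_eq!(summarize_serials(&[1, 2, 3]), "[1, 2, 3]");
    // Long lists show the first six and last three entries plus a count.
    let many: Vec<u32> = (1..=20).collect();
    assert_eq!(
        summarize_serials(&many),
        "[1, 2, 3, 4, 5, 6] ... [18, 19, 20] (total=20)"
    );
}
```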

(deleted file)

@@ -1,310 +0,0 @@
use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch};
use anyhow::{anyhow, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path;
use std::sync::Arc;
use tokio::task;
use crate::rtr::cache::{Delta, Snapshot};
use crate::rtr::state::State;
const CF_META: &str = "meta";
const CF_SNAPSHOT: &str = "snapshot";
const CF_DELTA: &str = "delta";
const META_STATE: &[u8] = b"state";
const META_SESSION_ID: &[u8] = b"session_id";
const META_SERIAL: &[u8] = b"serial";
const META_DELTA_MIN: &[u8] = b"delta_min";
const META_DELTA_MAX: &[u8] = b"delta_max";
const DELTA_KEY_PREFIX: u8 = b'd';
fn delta_key(serial: u32) -> [u8; 5] {
let mut key = [0u8; 5];
key[0] = DELTA_KEY_PREFIX;
key[1..].copy_from_slice(&serial.to_be_bytes());
key
}
fn delta_key_serial(key: &[u8]) -> Option<u32> {
if key.len() != 5 || key[0] != DELTA_KEY_PREFIX {
return None;
}
let mut bytes = [0u8; 4];
bytes.copy_from_slice(&key[1..]);
Some(u32::from_be_bytes(bytes))
}
#[derive(Clone)]
pub struct RtrStore {
db: Arc<DB>,
}
impl RtrStore {
/// Open or create DB with required column families.
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut opts = Options::default();
opts.create_if_missing(true);
opts.create_missing_column_families(true);
let cfs = vec![
ColumnFamilyDescriptor::new(CF_META, Options::default()),
ColumnFamilyDescriptor::new(CF_SNAPSHOT, Options::default()),
ColumnFamilyDescriptor::new(CF_DELTA, Options::default()),
];
let db = Arc::new(DB::open_cf_descriptors(&opts, path, cfs)?);
Ok(Self { db })
}
/// Common serialize/put.
fn put_cf<T: Serialize>(&self, cf: &str, key: &[u8], value: &T) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
let data = serde_json::to_vec(value)?;
self.db.put_cf(cf_handle, key, data)?;
Ok(())
}
/// Common get/deserialize.
fn get_cf<T: DeserializeOwned>(&self, cf: &str, key: &[u8]) -> Result<Option<T>> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
if let Some(value) = self.db.get_cf(cf_handle, key)? {
let obj = serde_json::from_slice(&value)?;
Ok(Some(obj))
} else {
Ok(None)
}
}
/// Common delete.
fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?;
self.db.delete_cf(cf_handle, key)?;
Ok(())
}
// ===============================
// Meta/state
// ===============================
pub fn set_state(&self, state: &State) -> Result<()> {
self.put_cf(CF_META, META_STATE, &state)
}
pub fn get_state(&self) -> Result<Option<State>> {
self.get_cf(CF_META, META_STATE)
}
pub fn set_meta(&self, meta: &State) -> Result<()> {
self.set_state(meta)
}
pub fn get_meta(&self) -> Result<Option<State>> {
self.get_state()
}
pub fn set_session_id(&self, session_id: u16) -> Result<()> {
self.put_cf(CF_META, META_SESSION_ID, &session_id)
}
pub fn get_session_id(&self) -> Result<Option<u16>> {
self.get_cf(CF_META, META_SESSION_ID)
}
pub fn set_serial(&self, serial: u32) -> Result<()> {
self.put_cf(CF_META, META_SERIAL, &serial)
}
pub fn get_serial(&self) -> Result<Option<u32>> {
self.get_cf(CF_META, META_SERIAL)
}
pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> {
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default();
batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?);
batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?);
self.db.write(batch)?;
Ok(())
}
pub fn get_delta_window(&self) -> Result<Option<(u32, u32)>> {
let min: Option<u32> = self.get_cf(CF_META, META_DELTA_MIN)?;
let max: Option<u32> = self.get_cf(CF_META, META_DELTA_MAX)?;
match (min, max) {
(Some(min), Some(max)) => Ok(Some((min, max))),
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: delta window mismatch")),
}
}
pub fn delete_state(&self) -> Result<()> {
self.delete_cf(CF_META, META_STATE)
}
pub fn delete_serial(&self) -> Result<()> {
self.delete_cf(CF_META, META_SERIAL)
}
// ===============================
// Snapshot
// ===============================
pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let mut batch = WriteBatch::default();
let data = serde_json::to_vec(snapshot)?;
batch.put_cf(cf_handle, b"current", data);
self.db.write(batch)?;
Ok(())
}
pub fn get_snapshot(&self) -> Result<Option<Snapshot>> {
self.get_cf(CF_SNAPSHOT, b"current")
}
pub fn delete_snapshot(&self) -> Result<()> {
self.delete_cf(CF_SNAPSHOT, b"current")
}
pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> {
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default();
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_STATE, serde_json::to_vec(state)?);
batch.put_cf(
meta_cf,
META_SESSION_ID,
serde_json::to_vec(&state.clone().session_id())?,
);
batch.put_cf(
meta_cf,
META_SERIAL,
serde_json::to_vec(&state.clone().serial())?,
);
self.db.write(batch)?;
Ok(())
}
pub fn save_snapshot_and_meta(
&self,
snapshot: &Snapshot,
session_id: u16,
serial: u32,
) -> Result<()> {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SESSION_ID, serde_json::to_vec(&session_id)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
self.db.write(batch)?;
Ok(())
}
pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
self.db.write(batch)?;
Ok(())
}
pub async fn save_snapshot_and_serial_async(
self: Arc<Self>,
snapshot: Snapshot,
serial: u32,
) -> Result<()> {
let snapshot_bytes = serde_json::to_vec(&snapshot)?;
let serial_bytes = serde_json::to_vec(&serial)?;
task::spawn_blocking(move || {
let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", snapshot_bytes);
batch.put_cf(meta_cf, META_SERIAL, serial_bytes);
self.db.write(batch)?;
Ok::<_, anyhow::Error>(())
})
.await??;
Ok(())
}
pub fn load_snapshot_and_state(&self) -> Result<Option<(Snapshot, State)>> {
let snapshot: Option<Snapshot> = self.get_snapshot()?;
let state: Option<State> = self.get_state()?;
match (snapshot, state) {
(Some(snap), Some(state)) => Ok(Some((snap, state))),
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")),
}
}
pub fn load_snapshot_and_serial(&self) -> Result<Option<(Snapshot, u32)>> {
let snapshot: Option<Snapshot> = self.get_snapshot()?;
let serial: Option<u32> = self.get_serial()?;
match (snapshot, serial) {
(Some(snap), Some(serial)) => Ok(Some((snap, serial))),
(None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")),
}
}
// ===============================
// Delta
// ===============================
pub fn save_delta(&self, delta: &Delta) -> Result<()> {
self.put_cf(CF_DELTA, &delta_key(delta.serial()), delta)
}
pub fn get_delta(&self, serial: u32) -> Result<Option<Delta>> {
self.get_cf(CF_DELTA, &delta_key(serial))
}
pub fn load_deltas_since(&self, serial: u32) -> Result<Vec<Delta>> {
let cf_handle = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let start_key = delta_key(serial.wrapping_add(1));
let iter = self.db.iterator_cf(
cf_handle,
IteratorMode::From(&start_key, Direction::Forward),
);
let mut out = Vec::new();
for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref())
.ok_or_else(|| anyhow!("Invalid delta key"))?;
if parsed <= serial {
continue;
}
let delta: Delta = serde_json::from_slice(value.as_ref())?;
out.push(delta);
}
Ok(out)
}
pub fn delete_delta(&self, serial: u32) -> Result<()> {
self.delete_cf(CF_DELTA, &delta_key(serial))
}
}
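`load_deltas_since` seeks to `delta_key(serial + 1)` and iterates forward, which only returns deltas in serial order if the key encoding sorts the same way RocksDB's default bytewise comparator does. A minimal sketch of such an order-preserving encoding; the crate's real `delta_key`/`delta_key_serial` are assumed to work like these illustrative stand-ins:

```rust
// Illustrative stand-ins for delta_key / delta_key_serial (assumption:
// the crate's real helpers use an order-preserving encoding like this).
fn delta_key(serial: u32) -> [u8; 4] {
    // Big-endian bytes: lexicographic key order == numeric serial order,
    // which is what IteratorMode::From(start, Direction::Forward) relies on.
    serial.to_be_bytes()
}

fn delta_key_serial(key: &[u8]) -> Option<u32> {
    Some(u32::from_be_bytes(key.try_into().ok()?))
}

fn main() {
    let mut keys: Vec<[u8; 4]> = [3u32, 256, 1, 2].iter().map(|&s| delta_key(s)).collect();
    keys.sort(); // RocksDB's default comparator also sorts keys bytewise
    let serials: Vec<u32> = keys.iter().filter_map(|k| delta_key_serial(k)).collect();
    assert_eq!(serials, vec![1, 2, 3, 256]);
    println!("{:?}", serials);
}
```

A little-endian or decimal-string encoding would break this ordering (e.g. `"10"` sorts before `"2"`), which is why big-endian fixed-width keys are the usual choice.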


@@ -1,12 +1,12 @@
-use std::net::{Ipv4Addr, Ipv6Addr};
 use std::fmt::Write;
+use std::net::{Ipv4Addr, Ipv6Addr};
 use serde_json::{json, Value};
 use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
-use rpki::rtr::cache::SerialResult;
 use rpki::rtr::payload::{Payload, RouteOrigin};
 use rpki::rtr::pdu::{CacheResponse, EndOfDataV1, IPv4Prefix, IPv6Prefix};
+use rpki::rtr::cache::SerialResult;
 pub struct RtrDebugDumper {
     entries: Vec<Value>,
@@ -163,12 +163,7 @@ pub fn v4_origin(
     RouteOrigin::new(prefix, max_len, asn.into())
 }
-pub fn v6_origin(
-    addr: Ipv6Addr,
-    prefix_len: u8,
-    max_len: u8,
-    asn: u32,
-) -> RouteOrigin {
+pub fn v6_origin(addr: Ipv6Addr, prefix_len: u8, max_len: u8, asn: u32) -> RouteOrigin {
     let prefix = v6_prefix(addr, prefix_len);
     RouteOrigin::new(prefix, max_len, asn.into())
 }
@@ -236,9 +231,8 @@ pub fn serial_result_to_string(result: &SerialResult) -> String {
     match result {
         SerialResult::UpToDate => "UpToDate".to_string(),
         SerialResult::ResetRequired => "ResetRequired".to_string(),
-        SerialResult::Deltas(deltas) => {
-            let serials: Vec<u32> = deltas.iter().map(|d| d.serial()).collect();
-            format!("Deltas {:?}", serials)
+        SerialResult::Delta(delta) => {
+            format!("Delta serial={}", delta.serial())
         }
     }
 }
@@ -266,12 +260,7 @@ pub fn print_snapshot_hashes(label: &str, snapshot: &rpki::rtr::cache::Snapshot)
     );
 }
-pub fn test_report(
-    name: &str,
-    purpose: &str,
-    input: &str,
-    output: &str,
-) {
+pub fn test_report(name: &str, purpose: &str, input: &str, output: &str) {
     println!(
         "\n==================== TEST REPORT ====================\nTest name: {}\nTest purpose: {}\n\n[Input]\n{}\n[Output]\n{}\n====================================================\n",
         name, purpose, input, output
@@ -307,16 +296,14 @@ pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -
         rpki::rtr::cache::SerialResult::ResetRequired => {
             "  result: ResetRequired\n".to_string()
         }
-        rpki::rtr::cache::SerialResult::Deltas(deltas) => {
+        rpki::rtr::cache::SerialResult::Delta(delta) => {
             let mut out = String::new();
-            let _ = writeln!(&mut out, "  result: Deltas");
-            for (idx, delta) in deltas.iter().enumerate() {
-                let _ = writeln!(&mut out, "  delta[{}].serial: {}", idx, delta.serial());
-                let _ = writeln!(&mut out, "  delta[{}].announced:", idx);
-                out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4));
-                let _ = writeln!(&mut out, "  delta[{}].withdrawn:", idx);
-                out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4));
-            }
+            let _ = writeln!(&mut out, "  result: Delta");
+            let _ = writeln!(&mut out, "  delta.serial: {}", delta.serial());
+            let _ = writeln!(&mut out, "  delta.announced:");
+            out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4));
+            let _ = writeln!(&mut out, "  delta.withdrawn:");
+            out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4));
             out
         }
     }
 }

tests/fixtures/tls/client-bad.cnf (new file, 14 lines)

@@ -0,0 +1,14 @@
[req]
distinguished_name = dn
prompt = no
req_extensions = req_ext
[dn]
CN = RTR Test Client Bad
[req_ext]
subjectAltName = @alt_names
extendedKeyUsage = clientAuth
[alt_names]
IP.1 = 127.0.0.2

tests/fixtures/tls/client-bad.crt (new file, 20 lines)

@@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDMTCCAhmgAwIBAgIUOiwVulOYZB53/kZ6zcqPl0p0g9UwDQYJKoZIhvcNAQEL
BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw
MFoXDTM2MDMyMDA4MDUwMFowHjEcMBoGA1UEAwwTUlRSIFRlc3QgQ2xpZW50IEJh
ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKatGSgyzw0FKy4n5pf6
W7UUJCm26WI1+W1GsAkh2EfGofD3ynE6Ne15jRjLc/PeDe0xlEl4wTQm4Gs8kkil
JhHzG4fFxUyi4MeRuQPha419hZ+pK1BHUIlURbg1oeXO5g3kwfhLPbVAQcrJdvwz
uSEGj22OmqgQFmn0Uwl6xvl/unIIzL2FXGedQA4CitaXk9dqbQWmIFDKuh0q8Kim
ePmy+8aHGuD/SeeP31ZvGFv3WcvzsX5aysRdj/WWiJUWbv1vVJ8rApswgH1qC2gY
F9lQuEMtfgo+39cUioRzD/lkzl1LnNHr/a21WGOS8ojvAiK3RC49VABVs3fRYGhc
cTECAwEAAaNoMGYwDwYDVR0RBAgwBocEfwAAAjATBgNVHSUEDDAKBggrBgEFBQcD
AjAdBgNVHQ4EFgQUWk94GyULLgNLb+38lHzbrVEVrIMwHwYDVR0jBBgwFoAUtiTC
qdhd2TwFc7JZGWkF8BEcMJwwDQYJKoZIhvcNAQELBQADggEBACHR2a6IkoHt0iOQ
DjMi45fgybvddssC1B2aDABziQDabQkrei8leEFptgCCNtMSWkmrfgZJyI2tvuws
296kDhkhIEDiwsFh24ZblYaAuKmFBEmw/v6VotFeQ9Hrsf+KKOT3jGibIb2+3ho3
q2X2e6ye7SOsfrs4hqeggcBQyzStSPoHE6KWNOHb6vBdJKavFsbff28mQIW3uEDK
j6xR7b6+xpuuuqwA41BulCTXIKGzDoIU1bet/7YmrlN69ElN97EfiT6u32dIqUN5
dwQKRlTJdf/ZxXVvEemTUcUdRDdLLyG+RMhUfiAr56G+aBn8SzMVEjMGS7bB5dsa
FhEqWj4=
-----END CERTIFICATE-----

tests/fixtures/tls/client-bad.key (new file, 28 lines)

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCmrRkoMs8NBSsu
J+aX+lu1FCQptuliNfltRrAJIdhHxqHw98pxOjXteY0Yy3Pz3g3tMZRJeME0JuBr
PJJIpSYR8xuHxcVMouDHkbkD4WuNfYWfqStQR1CJVEW4NaHlzuYN5MH4Sz21QEHK
yXb8M7khBo9tjpqoEBZp9FMJesb5f7pyCMy9hVxnnUAOAorWl5PXam0FpiBQyrod
KvCopnj5svvGhxrg/0nnj99Wbxhb91nL87F+WsrEXY/1loiVFm79b1SfKwKbMIB9
agtoGBfZULhDLX4KPt/XFIqEcw/5ZM5dS5zR6/2ttVhjkvKI7wIit0QuPVQAVbN3
0WBoXHExAgMBAAECggEAEtQwKNngiQSB9cNdKeMG9CRT14CJyeX2CPG41jdEEwut
5KZhcLwWNn8KQPhO34hAw4Bb05b4IHeZ15NupRU/AT0Y6ZQeb0NhKDpej8Ep4MCC
1WALfBSqVPE3tREd+nOOiphCj1WUoYKiTBHJRsFjEweGMWawMvcqSQroTIRZsPqd
uphHGiyF5KGIPMZv6UrdC9pMllQrdYp6edu5Jk/EhX85dFv96eRwI8tutzvJ5gkt
itpX9l6reeO6s0FszcgbvzU5VxCOSV1Oo6izjdpocxnnGfO81Lc+DFy0XwJ3PCAz
ABE9SVO6TzRMcUAr9x44WffIsG6Ea9tDTHqmGdsm4QKBgQDTsOL2a+yT4bGHwkXl
GdlSNk1aLiZR/b6L1KEjsGQ77mGS+EZYouIo7BmAt3pQ/YQ/yqp/+G3RIPSqkg/w
n1iloaNAsCS7fIiOMrAKfhX6Y93Yi3orUKx+f9Y8vgTqbVXuyo4n2IvmG0Dqm6uA
YuuTvG+J3lu1Tfy6FwcMAsAYOQKBgQDJkCrg+HMkfvSykAFvMw94DEtOiWU4/bF/
r6k299i++ubx1hqk48QWuMaXPK9FnsMzCeHukoHxi4JTQvu5rxED8MLg3nKmsl78
efLcMKuSkzDtGriI4EMnAPHTcnrc7fxneUg944EOxjb2AAm0mQx95sA+q4qFvEy6
Yq06tIJwuQKBgB6o+/Zc40L36VKUXLM17zftDX8GOB6f9b0i6sPUhG/5ssAqnWWx
EbiDmZ3+9QRN852ZqOAoBx/G+ijKRuy+54P1yUNRP8C35L9TsBOU93HwjO0UJnmn
kZQwx8K8ctHRTCTtyXET+A8320sfiNNrgFtBa5Y4UmgMB5KcSzT+IPxhAoGAEGFk
+q92PAsNO820MCNIKItnO1SzIzSKzkOqTstJlAuz5QdvVuMjtm0Bxpyp6dCDMIyn
DcpeQREDYFzbNDXj/hv82mV5j86DJaWLdRWHe/v2R+6Z/JWtH2hWPsbY8Udt8cLL
eiwY+uhk4w0RvNmLSFgOW4l5UnEBE0ydo120FBECgYEAxTNChdhMIHnFQQHLYI+6
SR6LqnjPwqdCQSzy4NrnGJ9O7c+Sed1L5RVg88OKG2dZo+GGjpsSAFRv3gB6rOX9
bYxklbPWC65bbWcIMNK2LCJ6+LahZzBwz2tBZOA5GBs0DunQ0PdqLnxdi7QZoHfb
pAuCQadp+4BQHVsDwIwpUH4=
-----END PRIVATE KEY-----

tests/fixtures/tls/client-ca.crt (new file, 19 lines)

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDGzCCAgOgAwIBAgIUKCaM+eEU12YVn/tC6MEBYPj50oIwDQYJKoZIhvcNAQEL
BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDQy
OFoXDTM2MDMyMDA4MDQyOFowHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENB
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmmS+gsRZXPaG20Ra2te7
WVeqyeFX2lV/RvNsT3WW+EtwuTN2aFJ/pBgSCqgCaFocU5IcgfScybQrjmTWeEot
D0F1OOW/TihlBsrU3BS8vOZ4KTuwB9eOgpt7FL8CxLetjNWxZKBsvJMauOUYFLCt
DN+78EDqrZ2VyJPfa+7dJKS/JoRgs/WFjvAeb5x1UpKPf/gp1Ji+Vj92Ux9znYF1
2Zfs77NsLdwcqyMVj7cb1nucutNf/vMqNi+BoSwC/tQkflda9bLUwto/ZfFPsCFw
Ptjwiwm+IGNAKG8r1ujaUVDtxc+NliuAbHERBgZGtYpAE6u28+d8cUAyGGQFNGqb
uQIDAQABo1MwUTAdBgNVHQ4EFgQUtiTCqdhd2TwFc7JZGWkF8BEcMJwwHwYDVR0j
BBgwFoAUtiTCqdhd2TwFc7JZGWkF8BEcMJwwDwYDVR0TAQH/BAUwAwEB/zANBgkq
hkiG9w0BAQsFAAOCAQEATr0wChmDvmw51RQ0kOFP1l3n+O1n/U0xR39YApnmdhtk
MyO9gNyqxdUWj4QaRkswr5iMyjsLrT3lYWGfh7oHlhIgkXEy5OK548rhj/PA6d5U
M7DDGI06EFfYXcC57Cx1y8Egd9gfkfPk0ned/TG+/dYtyb7sBiIKrOiDExbfZ53U
mNtfsdptBoOjee4KggplfVhkUopyO0KKD7twq4dvWgzWZHv7m5wrQA8GP8W4JpPP
ZAs8cOLfq4cxAwYFXEaf0x7is2u9KdagPqXkiPMeC8fOHWyRCR7sKN4Dnh57ruqq
mQXil9GGFfkDA/bh6+Rs93Yll7yYyA3xif6fSBmgUg==
-----END CERTIFICATE-----

tests/fixtures/tls/client-ca.key (new file, 28 lines)

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCaZL6CxFlc9obb
RFra17tZV6rJ4VfaVX9G82xPdZb4S3C5M3ZoUn+kGBIKqAJoWhxTkhyB9JzJtCuO
ZNZ4Si0PQXU45b9OKGUGytTcFLy85ngpO7AH146Cm3sUvwLEt62M1bFkoGy8kxq4
5RgUsK0M37vwQOqtnZXIk99r7t0kpL8mhGCz9YWO8B5vnHVSko9/+CnUmL5WP3ZT
H3OdgXXZl+zvs2wt3ByrIxWPtxvWe5y601/+8yo2L4GhLAL+1CR+V1r1stTC2j9l
8U+wIXA+2PCLCb4gY0AobyvW6NpRUO3Fz42WK4BscREGBka1ikATq7bz53xxQDIY
ZAU0apu5AgMBAAECggEABjSKW6McnFnkLaffpvAIvZyCZr7B0yqghO9/qOnm+W++
xhLFbYfzTVsSTo9WGW+Vt94leyujqY+uOHjhDdCdYwGUfobtW2zQMqewSnAi7cyt
g6q8dnQ5bBJnrfvHVrSzKvfju1GfTSz0Y/4BK8O2ENBlM1DIndW5kWgwEJx3EuPk
KzsjeZnWEUWKKcHL6/+jviTgfM7tnMjgtHMeHU/v8hVKBtebWAxvPjTE+HRoWePE
uEgvGC9eIHb7fRh0sEt4+vrx5GkKikwj4278yhTJTF4X4kTE8T1kBYaihAiSTyoy
jdf7DkGtfC//3k6yc3mSabq7UjMLu1m4HGCSXjs2iQKBgQDW/e+Qx7P/hyeya/gQ
JDAsZciBKxBn8gGywRbtEzFz08Wb1WMX0LaioUIVMXUmVzXWgQkaiztzS+AQTr+D
sRpbOqkxalUdzyaHzgUTKhyvU2dzPOWWFUwJvBPKShhUibukgl+BIIdXr701qKf5
W/sL3e7CMxNfirKfMoXe4q4m1wKBgQC318k5+63TexbP/peY6A+bGudJo/PKE4aD
GexJEvUcEV1J5d3SKXZoe2V0fTmFsvdhJIPFg2rnVRy2tV6F1vFzHAiou3tc0skU
vlM8nTDwPI80kk44LRMzzqZoe5pziI3tUS/Qwu8FX3MfXkkKd0wnQRhoUc0fMdT2
zOuvFehP7wKBgBQNsafeiNKf57sDySqwRXIOuGob+zbG4xOqYRoR/T3hlgAYIlsZ
U7/NrN1PNK8z2Ui91nyMWipB/I9o2QJOpbe2vAto8LGMHfry45RLDEvqSq78Eioy
qFoMGgh3ateP1Vnd80yXHSi3sr1rkud2he8wb1Hb88WoqUqiKsyEdlwXAoGBAIdN
V/rFoRv5BkQUAqx1di7YMQrAkIbTsfbA2Ga7fgunN/pQI94tx8iDsJp4IyKkIW6s
OhLecopI2LYba7KjC9aE9laAjP024OjUXlxI8CCO4XJ2jvzHJ8/EMjLJbVXEVXgo
fUFuhg11PzwB303FmRV20ijMs2NXAH6XOIoGXJCfAoGALCJ11wdb1m2vzYRLlSqo
mcDQ1PltSbnFZ6KyuJM5MnJ/TIwnUQ7Rnsjya5m4HFJtJnPSgrygckjrZwPLZpPz
TrKr614Ln4E3YU2IoVTyjjgtkEHn0fRcxvsn9z5vXCrzM4JJF5Ac+nbKEwgugmpu
JjMMsfGPFC/cr+SpRZr/Nf8=
-----END PRIVATE KEY-----

tests/fixtures/tls/client-good.cnf (new file, 14 lines)

@@ -0,0 +1,14 @@
[req]
distinguished_name = dn
prompt = no
req_extensions = req_ext
[dn]
CN = RTR Test Client Good
[req_ext]
subjectAltName = @alt_names
extendedKeyUsage = clientAuth
[alt_names]
IP.1 = 127.0.0.1

tests/fixtures/tls/client-good.crt (new file, 20 lines)

@@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDMjCCAhqgAwIBAgIUTLfIortKKjsYovPDuDVnUR8i2YMwDQYJKoZIhvcNAQEL
BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw
MFoXDTM2MDMyMDA4MDUwMFowHzEdMBsGA1UEAwwUUlRSIFRlc3QgQ2xpZW50IEdv
b2QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUjnhxQxKUNLjGN6mu
/VvoGXfsyvZq38QMlfIauGPQ5Jwpq2BU2gW1zFvjBJaS8ZKY2DOsBKvZMrEiSxVt
8WCCgvSucw5Elwniq/Yu8UmkvWEMnk9oXwQuHQeWYSe2eMvwBMTNaj3Vz+xo2Ga/
+eGClxPDcKKvOMYz+BRpyMfHzX1LLrM26yDjyxQKxzsrFMQ5IKXkhv4COLAjpLZd
oEDSkG+W4fbGTC6bSkskFP62OmnR/Qq6dXBVAJkzHqoiWEmKWBQy9JhEo+N3Jz7D
QfOAI/20MD6J4brWS56ciHL67yQjHJW6P0HlcLzefFqspIwvTFyCsF137jJi9sqL
y7UlAgMBAAGjaDBmMA8GA1UdEQQIMAaHBH8AAAEwEwYDVR0lBAwwCgYIKwYBBQUH
AwIwHQYDVR0OBBYEFJy0O+AR5EdWrUr2UHtPCUVTrRiEMB8GA1UdIwQYMBaAFLYk
wqnYXdk8BXOyWRlpBfARHDCcMA0GCSqGSIb3DQEBCwUAA4IBAQB8eavz/Pk3oJBh
sgy4ve6nLAqKGcD+zSoBxw/ErZaDNdUZZaFKT9nypfi39jkfjU+CIfhczCs5Cknk
EDqYtp1BXkY+auKeYRkoaCKv+ucnIc6JZ70NOQGdDNzU4eb9tVA9Py0j5VzvxWR1
yl6vHTc3tTA873RRezvUe+SQDwR+x2dxZ0O0MFZR4CTaTeR+fpD4bE9mBcYN1KVk
HdA7USCwOZ0qE+aHTegM4pgGrpHYS5yBpUFADqfJl1yBnnqc10zkphoy6dRsbVuN
NkDostGbbfGdGjEPs6i1TMsDTVyOo4nsHACk4q4m/IEVBkiMwxmwyhK1EEWQmQpy
h0CqAs3Z
-----END CERTIFICATE-----

tests/fixtures/tls/client-good.key (new file, 28 lines)

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDUjnhxQxKUNLjG
N6mu/VvoGXfsyvZq38QMlfIauGPQ5Jwpq2BU2gW1zFvjBJaS8ZKY2DOsBKvZMrEi
SxVt8WCCgvSucw5Elwniq/Yu8UmkvWEMnk9oXwQuHQeWYSe2eMvwBMTNaj3Vz+xo
2Ga/+eGClxPDcKKvOMYz+BRpyMfHzX1LLrM26yDjyxQKxzsrFMQ5IKXkhv4COLAj
pLZdoEDSkG+W4fbGTC6bSkskFP62OmnR/Qq6dXBVAJkzHqoiWEmKWBQy9JhEo+N3
Jz7DQfOAI/20MD6J4brWS56ciHL67yQjHJW6P0HlcLzefFqspIwvTFyCsF137jJi
9sqLy7UlAgMBAAECggEALmsOkm17WTJKR79QJw7dS0qEjgmk1qIXRkhYns01vyCt
mcv7NYyHQrRmPKV73Is04HwWjLJYdQ5E8KBFBcV4tgezN4WY0BHL7txu3sGCu58/
2mmYHcriNs/QIF8HNSocH0ZrVBCngFHv5tWbWsFPJh2oCz5FyM41OpQqoQ9f0ZoA
PKq3m7dNidFmlQPhhJ9KeLYwDfzTRy4sWBDrgZFDQ6ut9nO8/JpJvxZCSnOdbW1T
6pYk3uAKKcmrcdI2fsg7u0J1zkrkDnRdrDv1QCY/h1UXvR4KuMeYRylHWJpOKUiR
asxpJO+lh7TmDaEIqO+dSmTHawcKDYuIp0EV4/x7AQKBgQD/cgMhx97bJ3ym0ucc
xtt10jo0NTW7SJIXx/eUdBQNA3aGWfsqLrBI5qKrXLxTXF/vAs3LWHiHF6EgRxW/
7RbO+NZzAGxg53Y3PMqZ4AQbOm2JCIgT9ZhQc55+qowdtrts3njL+ayy3vWh3nPt
G78+rs4ZLCHhEutbkKqU7S3XSQKBgQDVBJ5l9d1QcqFqYnlmyhgfwJwpxo4N2dJR
18pCow8Mp8S9ROJx3OF4yNTRa6ipnX0oGdgchT63jXBDoKot3/4tqbjYj/Y3MmKL
crZ7dZEiAqkqE975qUfSlQS6BtknRKduzQzAJAKz0A+JbPIisPHIrBnG2PH9h2XA
TjA2rHRi/QKBgDIp65+IpqUW/g2swSIPky1yGWgDQwgCWl49MMuAeCeOFIqRxRcl
kAzg7fUFAx7Dtzsyq8NRHmo5I7U5AHZuUtpWV5bB8IafLcHvOEI7kdLfCH+uozp4
Mm8qJWfuihGTvv7EOaik4VtHGamuC8n2dvoSTfr3hbezhXC32ifg4+2xAoGATbjp
snoKzheFHbPgZ8jFFJDKadOwcQ1Q19vMSJQGIa/08Ln5hWH6Qn/EZsTJPVnhGIiV
eZKEV6SbmZE9ho97xl1uvFWKmIkhu4+XVWSIF8iwwFGPwbgqJIOKvfVRtioujRbz
2AdLlSANCy9dCZtWHMnufccaRE7qqUfd/5TcwmECgYA5md1JZ+p+0yTkjNf4McCe
mpZiXXaM8uC3/XiWlfcw3ag88H0XCqt2UPcX03MC9ZzywfTdoEHhPRfr48uRiH7K
Pi1T8E2E0GkSfk6zh6wJm8Kmm9+VJd2rUFrb35xDMUfjpS228oH9UBQtwaAV+jmy
q/APoG+aWt5YIeYElx9mzA==
-----END PRIVATE KEY-----

tests/fixtures/tls/server-dns.cnf (new file, 14 lines)

@@ -0,0 +1,14 @@
[req]
distinguished_name = dn
prompt = no
req_extensions = req_ext
[dn]
CN = localhost
[req_ext]
subjectAltName = @alt_names
extendedKeyUsage = serverAuth
[alt_names]
DNS.1 = localhost

tests/fixtures/tls/server-dns.crt (new file, 19 lines)

@@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDLDCCAhSgAwIBAgIUU9TJAAXI1pavsowFKFye0t8gdz4wDQYJKoZIhvcNAQEL
BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyNDAyMzA0
N1oXDTM2MDMyMTAyMzA0N1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkq
hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzXLZ2ZeySuuzCDEPXsHMym0py7Tuwizu
5a4AKQt+7d8b51x22W+I6ZQXYOCUml+C0PGs5TfQDXs0VbYj08flVxw9732FPqzp
PhkZJlEbnu7+nEW4mb+aM+pnPnW7iLeCvePGXicYO9YctQNA9Onxe+63QtEi2OtO
24wlqreToA2qQ7UOeNbpcY5p9psgXt9wW+8rq1gXmKsci8DV//1MessDTTX4tgIh
lOvKBIqx2E2kIIbZ6BwB7veKCkwqrFeVKzmwCXfkVQZkyfxKTBOvf3VFssnFJ7Ld
dYuAMTpl9jCHkNykGNlc0a9ohglEvAjqwgpUdMcuzMMR8Kf0DjqZqwIDAQABo20w
azAUBgNVHREEDTALgglsb2NhbGhvc3QwEwYDVR0lBAwwCgYIKwYBBQUHAwEwHQYD
VR0OBBYEFOYSmMCX8XUYA9jMGY2M1v5/xaimMB8GA1UdIwQYMBaAFLYkwqnYXdk8
BXOyWRlpBfARHDCcMA0GCSqGSIb3DQEBCwUAA4IBAQArYXsdMl+vPyTyv/oRzCo0
mWyVO3RZPPRJSrrMz4UjtVMIfq6pnV+AXPEcWL/zEfQNcRPjsJeezkG66AZbI+ug
fur90YVDhWfOgde4E3cVhZz90aM/jcRMVvwNj0XiaX4JpxsVhv5T4LC80aXm0r02
YfJyqwtNaZlKDntaW56q7nD5eaoqmYa+ogpdqwCIfGvManfRH6v6xmzgQRKnD1lc
LZPDZ9dmkQg2N/vdZfVUpB2+EZYF/9BuIvJyGKNYBjJiQGv2kbaFUl13mw7D/yGP
Zytpu5AIPl6ScYog8x6dSJMiYM+hO1bD9qOU0Kq48PaLQoQ4poYo7zdbiB8TZaKC
-----END CERTIFICATE-----

tests/fixtures/tls/server-dns.key (new file, 28 lines)

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDNctnZl7JK67MI
MQ9ewczKbSnLtO7CLO7lrgApC37t3xvnXHbZb4jplBdg4JSaX4LQ8azlN9ANezRV
tiPTx+VXHD3vfYU+rOk+GRkmURue7v6cRbiZv5oz6mc+dbuIt4K948ZeJxg71hy1
A0D06fF77rdC0SLY607bjCWqt5OgDapDtQ541ulxjmn2myBe33Bb7yurWBeYqxyL
wNX//Ux6ywNNNfi2AiGU68oEirHYTaQghtnoHAHu94oKTCqsV5UrObAJd+RVBmTJ
/EpME69/dUWyycUnst11i4AxOmX2MIeQ3KQY2VzRr2iGCUS8COrCClR0xy7MwxHw
p/QOOpmrAgMBAAECggEAKOWyBCS0c0GUNA9AFgbSM4Gjjk+IL9MuAz6t/X2yWLvP
HDFF32bahFTcioZqToiwy9MwLbj8i5J5Co7lduUV/E887Q25lo5pJ9lrLjt7GhpN
SOKAKur/UVJaFw7ss/yD5DURafCyXEb1E/t/1ME1NwyAIqbrHu9IlV/Cp8c/dd6n
BUZHKN38Fs8riUOTW4BwAA1dV0k+g8bKb0byR5WmXbXqWG/XjfUKcI4Dg9ILsTwG
QgI/nzRyAV2unVcdNhM4pMTJ4U4x1Dh/rsWyQ4b206namkpfaIH7j8ovRY2OHhzT
EuCbn1Cpfmxp68hRrvMIhLvqVkXOSWzKOkHZba/qYQKBgQD0m3WJu5UHXip14gY/
SmDm3vEcFjVZa+BmQtojs32ZfT/b1H5U+mfhXlVi0yf7B7nitQ9zOA7A2C8cxdtr
X9zd81MHLc7yUEOAlHBqSfBX/DyF1iRsdM7GYYxrsaECM7NIDukDR1Q/DvMZodRx
iVpRkNP3s87a1/QwvDKrwOPcCQKBgQDXBH1IZWcl1mrGgfAI7r5wXIJo0IrTlUcT
4nFSNv5YnrhnRzmLDRqxwlUs4II87nzsqv7OYryAg4n8us+X6XQv0hvTKcHsM7tl
vAO2EeQIER71tEg45HhmAXo2X2bE5kNMTr4COlb8k6R6r5K2F4yVMZim4qPcMYAK
tNRXf9JdEwKBgQDEjpOxvPmxdPrbxWfNvgAGJYpMTpBKLgShSAEwhRBdoacKCEQI
FzwYfoxQoGtVLk0yHtqudJJuZondLiT2sI60D85dS3MrhlHn5eA7mPS4Tyl3Rq/4
Mxjhkwuakp9WPKNJOSoHB29sSKASrdcf8QaR2rZqKqQDeVtxOhnhqFuxuQKBgHrk
8/53BteXj/vZtKpGWs658UebOl3oinGREZgeGo3oWhmdmgQh/0nueuRlhcrxvLFA
otavlHIXvLyYwaJgKqpSetjcmxw4DTn+llhwLVd3Aa0J1+W8oBwdaA6/xGtx+LEa
qHt5gNJoSLBevYoaN53mdQudqm5mVHrKFDvWsRPFAoGBALAkJiXM4y+aTcwREjnO
CgebDRru37mprDW2PDa4vpTydxIcXZukTEZqXbVnXNDY/Gkyv06jYoY2R7mzbgfl
7RJIFAQffkzoeFadQZqimIAN3sQ6RRzNqmAlS2D6zjylX7N2rDiPcS3LtrJZHEyk
gTFx7+6gjfPNqPiAB5su0IJm
-----END PRIVATE KEY-----

tests/fixtures/tls/server.cnf (new file, 14 lines)

@@ -0,0 +1,14 @@
[req]
distinguished_name = dn
prompt = no
req_extensions = req_ext
[dn]
CN = RTR Test Server
[req_ext]
subjectAltName = @alt_names
extendedKeyUsage = serverAuth
[alt_names]
IP.1 = 127.0.0.1

tests/fixtures/tls/server.crt (new file, 20 lines)

@@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDLTCCAhWgAwIBAgIUOiwVulOYZB53/kZ6zcqPl0p0g9QwDQYJKoZIhvcNAQEL
BQAwHTEbMBkGA1UEAwwSUlRSIFRlc3QgQ2xpZW50IENBMB4XDTI2MDMyMzA4MDUw
MFoXDTM2MDMyMDA4MDUwMFowGjEYMBYGA1UEAwwPUlRSIFRlc3QgU2VydmVyMIIB
IjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAn1qUW0bUWsq9pMNAzukncqyb
3sTbVcO+KfZfzBRwFfbWP1HmUziIsBq01b1m6Py6CL/DWBjgw4lJ8ZgxlhbjB9Zu
DVCietURIZIGMjuBiKOGW8vkCd3hlZqNph5sp2rUT2RBKym+orIApkt75epZB1cW
0KGsgnAhRpY6QAp4xalggtVp3HvUqapnVkRgix9AF3EAJJml+8QhGjgAn5tzFlOU
s5PjHeCDawTgfsAcoyRLi2E/OWAqJoEjujIPrkzfTC8d21yavQvIVeP+OyBRBwrw
kNZTyYy5LjO4tqZr8yLKdcMtIdAK2+7lEwWpR/LjFTYXwGs5GEb1Obon1ZdxEQID
AQABo2gwZjAPBgNVHREECDAGhwR/AAABMBMGA1UdJQQMMAoGCCsGAQUFBwMBMB0G
A1UdDgQWBBQQ7EAHYbMAQW6UdIcRjUmpMOyboTAfBgNVHSMEGDAWgBS2JMKp2F3Z
PAVzslkZaQXwERwwnDANBgkqhkiG9w0BAQsFAAOCAQEAffuJO/uKj23/3uFCt06+
JT/8lozBSCifvileFjMZmNHViWT3cZuD2NM1/0uFNf9k95yyFn8yVeS4RiobX75L
rcqILZUesF5bkD7lmPdazR3Cz0BcCGo9+DvgWaX35lBOsuIMXRG3aUs6/x939zyc
p/KtDjqAgd26nhefJlZeb+Z9UYdmMXzcD5n8doWmNgYeaVQraBvMazenSmcwLzn6
qzHjW5osUILBSl36KxuChHjHSa9aFPIrhoiGMR0w9oOGm6cHE3R4seNJNAX+YoZq
9PXgvVAHWvoSUuATY8iySNRiWUOmpybQuw3iqfgGighekZfxukLMSx2jq4WIquDg
Aw==
-----END CERTIFICATE-----

tests/fixtures/tls/server.key (new file, 28 lines)

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCfWpRbRtRayr2k
w0DO6SdyrJvexNtVw74p9l/MFHAV9tY/UeZTOIiwGrTVvWbo/LoIv8NYGODDiUnx
mDGWFuMH1m4NUKJ61REhkgYyO4GIo4Zby+QJ3eGVmo2mHmynatRPZEErKb6isgCm
S3vl6lkHVxbQoayCcCFGljpACnjFqWCC1Wnce9SpqmdWRGCLH0AXcQAkmaX7xCEa
OACfm3MWU5Szk+Md4INrBOB+wByjJEuLYT85YComgSO6Mg+uTN9MLx3bXJq9C8hV
4/47IFEHCvCQ1lPJjLkuM7i2pmvzIsp1wy0h0Arb7uUTBalH8uMVNhfAazkYRvU5
uifVl3ERAgMBAAECggEADNbvCjW4SYJ5akujLqjBmW9H9diVtaDacyYbTOW+rD5h
v+NY9A6jkNDuDiS/JHmsgaT11+TVQ1wN/a3eHPJGI60G3ALJvKzrPvG1lxmNU9Wd
L0tL2fGrSfMUg7SC27BzX9w7lf88kX5XKA7/8iQCPWGqgG/uZuoi/D2BfwR4+7AL
0jPnElGK2PAHBWcEWvAQVfBRxPB96trF58+FLE+blQff2LY3AOR5pVRVwdK6V148
4bT6tlPFcUabyVI70mqpw1FjoocEwMb8Fd7uOco5/0FulG3qe9DpZAjM2MWc4LRH
K3ZSI/QUp7CGEPcPeah3I/XHdqNrzRwKYj1nM3csoQKBgQDf/1NarvGaJNexjYHn
TZOZ6AOZZsP4MqdcNqNzV5K/Gmy1kZ5r7e5eH+SuZ1qH9jNLgNlusGWjmmrivYJo
qoiVDzE1tO5m3r08w/5187+TWuhUSOgObKFjtCS+GSR+YbGp/sxoWc5uXm0UZune
mWjtPRVwEXBCyrmIAGCGJGYMmQKBgQC2HuzGeJT1e+JxpgvDZKYU9mg0MDYQhejV
ibOUfSedSdw2P1vxQWAlQSPHbVpftMqXeM48b/B6IkzvwwSVNxV2I1agxFj0ZLTx
MQhDhet0CCFC0UBT9/Ch6JWQ4FbrM5OaNhLVcfeksc8B7cTYZB/lMeSE2r1GH8t1
HNqj8+6bOQKBgGwOYKiLYlOI2GB3siXh34VMTogu8fSGgwPR+9GFem4kEjMY10Kb
mfTgD9IuW5bhJueSddGW2MEumcddwk45jf/SP1v4N1V6t/FbXyKJfm5YWWFndkKX
FtfhLCRkPp2VBT7LgtIIGLRXaul/p+xRNzPS1sekMfKWlx/LhsTPREdpAoGBAKZY
UMRnVwdyBD7x/0SVJe13s24Xqwokhaqlf9VdC1XrJKyX6o7Nu9fLS7bX9vf71h/M
Q/OH+wpTUhqc8g6opX2mgXWOYgG4Cl1S/81NAOaWlmrFXhBUIwJ/wjz16+4gyezM
/x7eXeecUQvd9TIBIfDiRWvjr4XhfKCXnkyqfYJhAoGAGMUedwQLN11NLclbTiVo
+xgy1yMeZV8c49+mocwwLgoxINBwygSU3klZZ4Bg+whzaVVlHccrHtEhw+73CbiM
xmfIuD0C/U/0FIP2kPm/g8rtH1QVE6Rkshwzzf+txeaGSlXvVtq4zjlVbDi+UrYk
qhSRT6NNKMROWj3Mx3CIV/A=
-----END PRIVATE KEY-----
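The fixtures above can be regenerated with OpenSSL. The exact commands used for this commit are not recorded; the following is a sketch under assumed parameters (RSA-2048, ten-year validity, IP SAN), producing a CA-signed server certificate from a config equivalent to tests/fixtures/tls/server.cnf:

```shell
set -e
dir=$(mktemp -d)
# Self-signed test CA (stands in for tests/fixtures/tls/client-ca.*).
openssl req -x509 -newkey rsa:2048 -nodes -days 3650 \
  -subj "/CN=RTR Test Client CA" \
  -keyout "$dir/client-ca.key" -out "$dir/client-ca.crt"
# Server config with an IP SAN, mirroring tests/fixtures/tls/server.cnf.
cat > "$dir/server.cnf" <<'EOF'
[req]
distinguished_name = dn
prompt = no
req_extensions = req_ext
[dn]
CN = RTR Test Server
[req_ext]
subjectAltName = @alt_names
extendedKeyUsage = serverAuth
[alt_names]
IP.1 = 127.0.0.1
EOF
# CSR, then sign it with the CA, carrying over the SAN/EKU extensions.
openssl req -new -config "$dir/server.cnf" -newkey rsa:2048 -nodes \
  -keyout "$dir/server.key" -out "$dir/server.csr"
openssl x509 -req -in "$dir/server.csr" \
  -CA "$dir/client-ca.crt" -CAkey "$dir/client-ca.key" -CAcreateserial \
  -days 3650 -extensions req_ext -extfile "$dir/server.cnf" \
  -out "$dir/server.crt"
openssl verify -CAfile "$dir/client-ca.crt" "$dir/server.crt"
```

The client-good/client-bad pairs follow the same recipe with `clientAuth` EKU and differing SAN IPs, so the "bad" certificate chains correctly but fails SAN matching.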

(file diff suppressed because it is too large)

tests/test_pdu.rs (new file, 204 lines)

@@ -0,0 +1,204 @@
use std::net::Ipv4Addr;
use tokio::io::{duplex, AsyncWriteExt};
use rpki::data_model::resources::as_resources::Asn;
use rpki::rtr::error_type::ErrorCode;
use rpki::rtr::payload::{Aspa as PayloadAspa, Ski, Timing};
use rpki::rtr::pdu::{
Aspa, EndOfDataV1, ErrorReport, Flags, Header, IPv4Prefix, RouterKey, SerialNotify,
END_OF_DATA_V1_LEN, MAX_PDU_LEN,
};
const ERROR_REPORT_FIXED_PART_LEN: usize = 16;
#[tokio::test]
async fn serial_notify_roundtrip() {
let (mut client, mut server) = duplex(1024);
let original = SerialNotify::new(1, 42, 100);
tokio::spawn(async move {
original.write(&mut client).await.unwrap();
});
let decoded = SerialNotify::read(&mut server).await.unwrap();
assert_eq!(decoded.version(), 1);
assert_eq!(decoded.session_id(), 42);
assert_eq!(decoded.serial_number(), 100);
}
#[tokio::test]
async fn ipv4_prefix_roundtrip() {
let (mut client, mut server) = duplex(1024);
let prefix = IPv4Prefix::new(
1,
Flags::new(1),
24,
24,
Ipv4Addr::new(192, 168, 0, 0),
65000u32.into(),
);
tokio::spawn(async move {
prefix.write(&mut client).await.unwrap();
});
let decoded = IPv4Prefix::read(&mut server).await.unwrap();
assert_eq!(decoded.prefix_len(), 24);
assert_eq!(decoded.max_len(), 24);
assert_eq!(decoded.prefix(), Ipv4Addr::new(192, 168, 0, 0));
assert!(decoded.flag().is_announce());
}
#[test]
fn error_report_truncates_large_erroneous_pdu() {
let pdu = vec![0xAA; MAX_PDU_LEN as usize];
let text = b"details";
let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), &pdu, text);
assert_eq!(report.as_ref().len(), MAX_PDU_LEN as usize);
assert_eq!(
report.erroneous_pdu(),
&pdu[..(MAX_PDU_LEN as usize - ERROR_REPORT_FIXED_PART_LEN)]
);
assert!(report.text().is_empty());
}
#[test]
fn error_report_truncates_text_to_fit() {
let pdu = [1, 2, 3, 4];
let text = vec![b'x'; MAX_PDU_LEN as usize];
let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), pdu, &text);
assert_eq!(report.erroneous_pdu(), pdu);
assert_eq!(report.as_ref().len(), MAX_PDU_LEN as usize);
assert_eq!(
report.text().len(),
MAX_PDU_LEN as usize - ERROR_REPORT_FIXED_PART_LEN - pdu.len()
);
}
#[tokio::test]
async fn error_report_rejects_non_utf8_text() {
let (mut client, mut server) = duplex(1024);
let header = Header::new(1, ErrorReport::PDU, ErrorCode::CorruptData.as_u16(), 17);
let mut bytes = Vec::from(header.as_ref());
bytes.extend_from_slice(&0u32.to_be_bytes());
bytes.extend_from_slice(&1u32.to_be_bytes());
bytes.push(0xFF);
client.write_all(&bytes).await.unwrap();
let err = ErrorReport::read(&mut server).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
}
#[tokio::test]
async fn router_key_length_matches_wire_size() {
let ski = Ski::default();
let spki: std::sync::Arc<[u8]> = std::sync::Arc::from(vec![1u8; 32]);
let pdu = RouterKey::new(1, Flags::new(1), ski, Asn::from(64496u32), spki);
let (mut client, mut server) = duplex(1024);
tokio::spawn(async move {
pdu.write(&mut client).await.unwrap();
});
let header = Header::read(&mut server).await.unwrap();
assert_eq!(header.pdu(), RouterKey::PDU);
assert_eq!(header.length(), 8 + 20 + 4 + 32);
}
#[tokio::test]
async fn aspa_length_matches_wire_size() {
let pdu = Aspa::new(2, Flags::new(1), 64496, vec![64497, 64498]);
let (mut client, mut server) = duplex(1024);
tokio::spawn(async move {
pdu.write(&mut client).await.unwrap();
});
let header = Header::read(&mut server).await.unwrap();
assert_eq!(header.pdu(), Aspa::PDU);
assert_eq!(header.length(), 8 + 4 + 8);
}
#[test]
fn aspa_announcement_rejects_empty_provider_list() {
let err = PayloadAspa::new(Asn::from(64496u32), vec![])
.validate_announcement()
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("at least one provider"));
}
#[test]
fn aspa_announcement_rejects_as0() {
let err = PayloadAspa::new(Asn::from(0u32), vec![Asn::from(64497u32)])
.validate_announcement()
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("AS0"));
let err = PayloadAspa::new(Asn::from(64496u32), vec![Asn::from(0u32)])
.validate_announcement()
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("AS0"));
}
#[test]
fn timing_rejects_out_of_range_refresh() {
let err = Timing::new(0, 600, 7200).validate().unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("refresh interval"));
}
#[test]
fn timing_rejects_expire_not_greater_than_retry_and_refresh() {
let err = Timing::new(600, 600, 600).validate().unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("expire interval"));
}
#[test]
fn end_of_data_v1_rejects_invalid_timing() {
let err = EndOfDataV1::new(1, 42, 100, Timing::new(600, 8000, 7200)).unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("retry interval"));
}
#[tokio::test]
async fn end_of_data_v1_read_rejects_invalid_timing() {
let (mut client, mut server) = duplex(1024);
let header = Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN);
let mut bytes = Vec::from(header.as_ref());
bytes.extend_from_slice(&100u32.to_be_bytes());
bytes.extend_from_slice(&600u32.to_be_bytes());
bytes.extend_from_slice(&8000u32.to_be_bytes());
bytes.extend_from_slice(&7200u32.to_be_bytes());
client.write_all(&bytes).await.unwrap();
let err = EndOfDataV1::read(&mut server).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("retry interval"));
}
#[tokio::test]
async fn end_of_data_v1_read_payload_rejects_invalid_timing() {
let (mut client, mut server) = duplex(1024);
let header = Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN);
client.write_all(&100u32.to_be_bytes()).await.unwrap();
client.write_all(&600u32.to_be_bytes()).await.unwrap();
client.write_all(&600u32.to_be_bytes()).await.unwrap();
client.write_all(&600u32.to_be_bytes()).await.unwrap();
let err = EndOfDataV1::read_payload(header, &mut server).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains("expire interval"));
}
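The length arithmetic asserted above follows the RFC 8210 wire format: every PDU starts with an 8-byte header (version, PDU type, a 16-bit type-specific field, 32-bit total length), big-endian throughout. As a standalone reference, a Serial Notify PDU (type 0, 12 bytes total) can be assembled by hand; this is the RFC layout, not this crate's API:

```rust
// RFC 8210 Serial Notify, hand-assembled:
// byte 0: protocol version (1), byte 1: PDU type (0 = Serial Notify),
// bytes 2-3: session id, bytes 4-7: total length (12),
// bytes 8-11: serial number. All fields big-endian.
fn serial_notify_bytes(session_id: u16, serial: u32) -> [u8; 12] {
    let mut buf = [0u8; 12];
    buf[0] = 1; // version 1
    buf[1] = 0; // Serial Notify PDU type
    buf[2..4].copy_from_slice(&session_id.to_be_bytes());
    buf[4..8].copy_from_slice(&12u32.to_be_bytes());
    buf[8..12].copy_from_slice(&serial.to_be_bytes());
    buf
}

fn main() {
    let pdu = serial_notify_bytes(42, 100);
    assert_eq!(&pdu[..2], &[1, 0]);
    assert_eq!(u16::from_be_bytes([pdu[2], pdu[3]]), 42);
    assert_eq!(u32::from_be_bytes([pdu[4], pdu[5], pdu[6], pdu[7]]), 12);
    assert_eq!(u32::from_be_bytes([pdu[8], pdu[9], pdu[10], pdu[11]]), 100);
    println!("{:02x?}", pdu);
}
```

The same header layout is why `error_report_truncates_large_erroneous_pdu` subtracts a 16-byte fixed part: 8 header bytes plus the two 4-byte length fields that frame the embedded PDU and text.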

(file diff suppressed because it is too large)

@ -6,9 +6,9 @@ use common::test_helper::{
indent_block, payloads_to_string, test_report, v4_origin, v6_origin, indent_block, payloads_to_string, test_report, v4_origin, v6_origin,
}; };
use rpki::rtr::cache::{Delta, Snapshot}; use rpki::rtr::cache::{CacheAvailability, Delta, SessionIds, Snapshot};
use rpki::rtr::payload::Payload; use rpki::rtr::payload::Payload;
use rpki::rtr::store_db::RtrStore; use rpki::rtr::store::RtrStore;
fn snapshot_to_string(snapshot: &Snapshot) -> String { fn snapshot_to_string(snapshot: &Snapshot) -> String {
let payloads = snapshot.payloads_for_rtr(); let payloads = snapshot.payloads_for_rtr();
@ -69,37 +69,52 @@ fn store_db_save_and_get_snapshot() {
fn store_db_set_and_get_meta_fields() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    let session_ids = SessionIds::from_array([40, 41, 42]);
    store.set_session_ids(&session_ids).unwrap();
    store.set_serial(100).unwrap();
    store.set_delta_window(101, 110).unwrap();
    let loaded_session_ids = store.get_session_ids().unwrap();
    let serial = store.get_serial().unwrap();
    let window = store.get_delta_window().unwrap();
    let input = format!(
        "db_path: {}\nset_session_ids={:?}\nset_serial=100\nset_delta_window=(101, 110)\n",
        dir.path().display(),
        session_ids,
    );
    let output = format!(
        "get_session_ids: {:?}\nget_serial: {:?}\nget_delta_window: {:?}\n",
        loaded_session_ids, serial, window,
    );
    test_report(
        "store_db_set_and_get_meta_fields",
        "Verify that session_ids / serial / delta_window are written and read back correctly.",
        &input,
        &output,
    );
    assert_eq!(loaded_session_ids, Some(session_ids));
    assert_eq!(serial, Some(100));
    assert_eq!(window, Some((101, 110)));
}
#[test]
fn store_db_clear_delta_window_removes_both_bounds() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    store.set_delta_window(101, 110).unwrap();
    assert_eq!(store.get_delta_window().unwrap(), Some((101, 110)));
    store.clear_delta_window().unwrap();
    assert_eq!(store.get_delta_window().unwrap(), None);
}
#[test]
fn store_db_save_and_get_delta() {
    let dir = tempfile::tempdir().unwrap();
@@ -194,43 +209,220 @@ fn store_db_load_deltas_since_returns_only_newer_deltas_in_order() {
fn store_db_save_snapshot_and_meta_writes_all_fields() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    let session_ids = SessionIds::from_array([40, 41, 42]);
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
        Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)),
    ]);
    store
        .save_snapshot_and_meta(&snapshot, &session_ids, 100)
        .unwrap();
    let loaded_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist");
    let loaded_session_ids = store.get_session_ids().unwrap();
    let loaded_serial = store.get_serial().unwrap();
    let input = format!(
        "db_path: {}\nsnapshot:\n{}session_ids={:?}\nserial=100\n",
        dir.path().display(),
        indent_block(&snapshot_to_string(&snapshot), 2),
        session_ids,
    );
    let output = format!(
        "loaded_snapshot:\n{}loaded_session_ids: {:?}\nloaded_serial: {:?}\n",
        indent_block(&snapshot_to_string(&loaded_snapshot), 2),
        loaded_session_ids,
        loaded_serial,
    );
    test_report(
        "store_db_save_snapshot_and_meta_writes_all_fields",
        "Verify that save_snapshot_and_meta() writes the snapshot, session_ids, and serial together.",
        &input,
        &output,
    );
    assert!(snapshot.same_content(&loaded_snapshot));
    assert_eq!(loaded_session_ids, Some(session_ids));
    assert_eq!(loaded_serial, Some(100));
}
#[test]
fn store_db_save_cache_state_writes_delta_snapshot_meta_and_window_together() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    let session_ids = SessionIds::from_array([40, 41, 42]);
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
        Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)),
    ]);
    let delta = Delta::new(
        101,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
    );
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            101,
            Some(&delta),
            Some((101, 101)),
            false,
        )
        .unwrap();
    let loaded_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist");
    let loaded_session_ids = store.get_session_ids().unwrap();
    let loaded_serial = store.get_serial().unwrap();
    let loaded_availability = store.get_availability().unwrap();
    let loaded_delta = store.get_delta(101).unwrap().expect("delta should exist");
    let loaded_window = store.get_delta_window().unwrap();
    assert!(snapshot.same_content(&loaded_snapshot));
    assert_eq!(loaded_session_ids, Some(session_ids));
    assert_eq!(loaded_serial, Some(101));
    assert_eq!(loaded_availability, Some(CacheAvailability::Ready));
    assert_eq!(loaded_delta.serial(), 101);
    assert_eq!(loaded_window, Some((101, 101)));
}
#[test]
fn store_db_save_cache_state_prunes_deltas_older_than_window_min() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    let session_ids = SessionIds::from_array([40, 41, 42]);
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
    ]);
    let d101 = Delta::new(
        101,
        vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
        vec![],
    );
    let d102 = Delta::new(
        102,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![],
    );
    let d103 = Delta::new(
        103,
        vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))],
        vec![],
    );
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            101,
            Some(&d101),
            Some((101, 101)),
            false,
        )
        .unwrap();
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            102,
            Some(&d102),
            Some((101, 102)),
            false,
        )
        .unwrap();
    // The final window is (103, 103); deltas 101 and 102 fall outside it
    // and must be pruned from the store.
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            103,
            Some(&d103),
            Some((103, 103)),
            false,
        )
        .unwrap();
    assert!(store.get_delta(101).unwrap().is_none());
    assert!(store.get_delta(102).unwrap().is_none());
    assert_eq!(store.get_delta(103).unwrap().map(|d| d.serial()), Some(103));
    assert_eq!(store.get_delta_window().unwrap(), Some((103, 103)));
}
#[test]
fn store_db_load_delta_window_restores_wraparound_window_in_serial_order() {
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    let session_ids = SessionIds::from_array([40, 41, 42]);
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
    ]);
    // RTR serials wrap around, so the window may span u32::MAX -> 0 -> 1.
    let d_max = Delta::new(
        u32::MAX,
        vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
        vec![],
    );
    let d_zero = Delta::new(
        0,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![],
    );
    let d_one = Delta::new(
        1,
        vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))],
        vec![],
    );
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            u32::MAX,
            Some(&d_max),
            Some((u32::MAX, u32::MAX)),
            false,
        )
        .unwrap();
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            0,
            Some(&d_zero),
            Some((u32::MAX, 0)),
            false,
        )
        .unwrap();
    store
        .save_cache_state(
            CacheAvailability::Ready,
            &snapshot,
            &session_ids,
            1,
            Some(&d_one),
            Some((u32::MAX, 1)),
            false,
        )
        .unwrap();
    let loaded = store.load_delta_window(u32::MAX, 1).unwrap();
    assert_eq!(
        loaded.iter().map(Delta::serial).collect::<Vec<_>>(),
        vec![u32::MAX, 0, 1]
    );
}
#[test]
fn store_db_load_snapshot_and_serial_returns_consistent_pair() {
    let dir = tempfile::tempdir().unwrap();
@@ -349,4 +541,4 @@ fn store_db_load_snapshot_and_serial_errors_on_inconsistent_state() {
    );
    assert!(result.is_err());
}