RTR server以及client的开发、test的补充

This commit is contained in:
xiuting.xu 2026-03-17 16:36:42 +08:00
parent 251dea8e5e
commit 9cbea4e2d0
31 changed files with 4970 additions and 260 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
target/
Cargo.lock
rtr-db/

View File

@ -18,9 +18,15 @@ chrono = "0.4.44"
bytes = "1.11.1"
tokio = { version = "1.49.0", features = ["full"] }
rand = "0.10.0"
rocksdb = "0.21"
serde = { version = "1", features = ["derive"] }
rocksdb = { version = "0.21.0", default-features = false }
serde = { version = "1", features = ["derive", "rc"] }
serde_json = "1"
anyhow = "1"
bincode = "3.0.0"
tracing = "0.1.44"
sha2 = "0.10"
tempfile = "3"
tokio-rustls = "0.26"
rustls = "0.23"
rustls-pemfile = "2"
rustls-pki-types = "1.14.0"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }

View File

@ -0,0 +1,63 @@
# rtr_debug_client
`rtr_debug_client` 是一个用于调试和联调 RTR(RPKI-to-Router)服务端的小型客户端工具。
它的目标不是做一个完整的生产级 router client,而是提供一个简单、直接、可观察的调试入口,用于:
- 连接 RTR server
- 发送 `Reset Query` 和 `Serial Query`
- 接收并打印服务端返回的 PDU
- 辅助排查协议实现、会话状态、序列号增量、PDU 编码等问题
---
## 适用场景
这个工具适合以下场景:
- 开发 RTR server 时做本地联调
- 验证服务端是否正确返回 `Cache Response`
- 检查 `IPv4 Prefix` / `IPv6 Prefix` / `ASPA` / `End of Data` 等 PDU
- 验证 `Serial Query` 路径是否正确
- 观察异常响应,例如 `Cache Reset` 和 `Error Report`
- 后续扩展为支持 TLS、自动断言、会话统计等调试能力
---
## 当前能力
当前版本支持:
- TCP 连接 RTR server
- 发送 `Reset Query`
- 发送 `Serial Query`
- 持续读取服务端返回的 PDU
- 解析并打印以下常见 PDU
- `Serial Notify`
- `Serial Query`
- `Reset Query`
- `Cache Response`
- `IPv4 Prefix`
- `IPv6 Prefix`
- `End of Data`
- `Cache Reset`
- `Error Report`
- `ASPA`
- 基础长度校验
- 最大 PDU 长度限制,防止异常数据导致过大内存分配
---
## 目录结构
建议目录如下:
```text
src/
└── bin/
└── rtr_debug_client/
├── main.rs
├── protocol.rs
├── io.rs
├── pretty.rs
└── README.md

View File

@ -0,0 +1,582 @@
use std::env;
use std::io;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::time::{timeout, Duration, Instant};
mod wire;
mod pretty;
mod protocol;
use crate::wire::{read_pdu, send_reset_query, send_serial_query};
use crate::pretty::{
parse_end_of_data_info, parse_serial_notify_serial, print_pdu,
};
use crate::protocol::{PduHeader, PduType, QueryMode};
const DEFAULT_READ_TIMEOUT_SECS: u64 = 30;
const DEFAULT_POLL_INTERVAL_SECS: u64 = 60;
/// Entry point of the interactive RTR debug client.
///
/// Parses CLI args, connects over plain TCP, sends the initial query
/// (Reset or Serial, per `config.mode`), then multiplexes three event
/// sources in a `select!` loop: console commands on stdin, the auto-poll
/// timer, and incoming PDUs from the server.
#[tokio::main]
async fn main() -> io::Result<()> {
    let config = Config::from_args()?;
    // Banner: echo the effective configuration before connecting.
    println!("== RTR debug client ==");
    println!("target : {}", config.addr);
    println!("version : {}", config.version);
    println!("timeout : {}s", config.read_timeout_secs);
    println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs);
    match &config.mode {
        QueryMode::Reset => {
            println!("mode : reset");
        }
        QueryMode::Serial { session_id, serial } => {
            println!("mode : serial");
            println!("session : {}", session_id);
            println!("serial : {}", serial);
        }
    }
    println!();
    print_help();
    let stream = TcpStream::connect(&config.addr).await?;
    println!("connected to {}", config.addr);
    // Split so the read half can sit inside `select!` while writes happen
    // from the command/poll handlers.
    let (mut reader, mut writer) = stream.into_split();
    let mut state = ClientState::new(
        config.version,
        config.read_timeout_secs,
        config.default_poll_secs,
    );
    // Kick off the session with the user-selected query type.
    match config.mode {
        QueryMode::Reset => {
            send_reset_query(&mut writer, config.version).await?;
            println!("sent Reset Query");
        }
        QueryMode::Serial { session_id, serial } => {
            state.session_id = Some(session_id);
            state.serial = Some(serial);
            send_serial_query(&mut writer, config.version, session_id, serial).await?;
            println!("sent Serial Query");
        }
    }
    state.schedule_next_poll();
    println!();
    let stdin = tokio::io::stdin();
    let mut stdin_lines = BufReader::new(stdin).lines();
    loop {
        // Re-armed every iteration; `schedule_next_poll` moves the deadline.
        let poll_sleep = tokio::time::sleep_until(state.next_poll_deadline);
        tokio::pin!(poll_sleep);
        tokio::select! {
            // Interactive console commands (help/state/reset/serial/...).
            line = stdin_lines.next_line() => {
                match line {
                    Ok(Some(line)) => {
                        let should_quit = handle_console_command(
                            &line,
                            &mut writer,
                            &mut state,
                        ).await?;
                        if should_quit {
                            println!("quit requested, closing client.");
                            break;
                        }
                    }
                    Ok(None) => {
                        println!("stdin closed, continue network loop.");
                    }
                    Err(err) => {
                        eprintln!("read stdin failed: {}", err);
                    }
                }
            }
            // Auto-poll timer: refresh from the cache periodically.
            _ = &mut poll_sleep => {
                handle_poll_tick(&mut writer, &mut state).await?;
                state.schedule_next_poll();
            }
            // NOTE(review): `read_pdu` uses `read_exact`, which is NOT
            // cancellation-safe; if another select arm wins while a PDU is
            // partially read, those bytes are lost and the stream desyncs.
            // Acceptable for a debug tool, but worth restructuring (e.g. a
            // dedicated reader task) — TODO confirm.
            read_result = timeout(
                Duration::from_secs(state.read_timeout_secs),
                read_pdu(&mut reader)
            ) => {
                match read_result {
                    Ok(Ok(pdu)) => {
                        print_pdu(&pdu.header, &pdu.body);
                        handle_incoming_pdu(&mut writer, &mut state, &pdu.header, &pdu.body).await?;
                    }
                    Ok(Err(err)) => {
                        eprintln!("read PDU failed: {}", err);
                        return Err(err);
                    }
                    Err(_) => {
                        // Read timeout is informational only; the TCP
                        // connection stays open.
                        println!(
                            "[timeout] no PDU received in {}s, connection kept open.",
                            state.read_timeout_secs
                        );
                    }
                }
            }
        }
    }
    Ok(())
}
/// Updates client state (and replies where the protocol requires it) for a
/// single PDU received from the server. The PDU has already been printed by
/// the caller; this function only performs the protocol reaction.
async fn handle_incoming_pdu(
    writer: &mut OwnedWriteHalf,
    state: &mut ClientState,
    header: &PduHeader,
    body: &[u8],
) -> io::Result<()> {
    match header.pdu_type() {
        // Start of a data exchange: remember the session the server is using.
        PduType::CacheResponse => {
            state.current_session_id = Some(header.session_id());
        }
        // Payload PDUs: adopt the session id only if none is known yet.
        PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::Aspa => {
            if state.current_session_id.is_none() {
                state.current_session_id = Some(header.session_id());
            }
        }
        // End of Data commits session/serial and (v1+) the timing hints,
        // which reshape the auto-poll interval.
        PduType::EndOfData => {
            let session_id = header.session_id();
            let eod = parse_end_of_data_info(body);
            state.session_id = Some(session_id);
            state.current_session_id = Some(session_id);
            println!();
            if let Some(eod) = eod {
                state.serial = Some(eod.serial);
                state.refresh = eod.refresh;
                state.retry = eod.retry;
                state.expire = eod.expire;
                println!(
                    "updated client state: session_id={}, serial={}",
                    session_id, eod.serial
                );
                if let Some(refresh) = eod.refresh {
                    println!("refresh : {}", refresh);
                }
                if let Some(retry) = eod.retry {
                    println!("retry : {}", retry);
                }
                if let Some(expire) = eod.expire {
                    println!("expire : {}", expire);
                }
                state.schedule_next_poll();
                println!(
                    "next auto poll scheduled after {}s",
                    state.effective_poll_secs()
                );
            } else {
                // Body had an unexpected length; keep the session id but
                // leave serial/timing untouched.
                println!(
                    "updated client state: session_id={}, serial=<unknown>",
                    session_id
                );
            }
            println!("received EndOfData, keep connection open.");
            println!();
        }
        // Serial Notify: ask for the delta with OUR current serial if the
        // session matches; otherwise fall back to a full Reset Query.
        PduType::SerialNotify => {
            let notify_session_id = header.session_id();
            let notify_serial = parse_serial_notify_serial(body);
            println!();
            match (state.session_id, state.serial, notify_serial) {
                (Some(current_session_id), Some(current_serial), Some(_new_serial))
                    if current_session_id == notify_session_id =>
                {
                    println!(
                        "received Serial Notify for current session {}, send Serial Query with serial {}",
                        current_session_id, current_serial
                    );
                    send_serial_query(
                        writer,
                        state.version,
                        current_session_id,
                        current_serial,
                    )
                    .await?;
                }
                _ => {
                    println!(
                        "received Serial Notify but local session/serial state is not usable, send Reset Query"
                    );
                    send_reset_query(writer, state.version).await?;
                }
            }
            state.schedule_next_poll();
            println!();
        }
        // Cache Reset: our serial is useless; drop it and restart with a
        // Reset Query.
        PduType::CacheReset => {
            println!();
            println!("received Cache Reset, send Reset Query");
            state.current_session_id = None;
            state.serial = None;
            send_reset_query(writer, state.version).await?;
            state.schedule_next_poll();
            println!();
        }
        // Debug tool policy: do NOT drop the connection on Error Report,
        // so the session can still be inspected interactively.
        PduType::ErrorReport => {
            println!();
            println!("received Error Report, keep connection open for debugging.");
            if let Some(retry) = state.retry {
                println!("will keep auto polling; server retry hint currently stored: {}s", retry);
            }
            println!();
        }
        PduType::SerialQuery | PduType::ResetQuery | PduType::Unknown(_) => {
            // only print, no extra action
        }
    }
    Ok(())
}
/// Runs when the auto-poll timer fires: refreshes from the cache using the
/// best query local state allows.
///
/// A known `session_id`/`serial` pair yields an incremental Serial Query;
/// otherwise the client falls back to a full Reset Query.
async fn handle_poll_tick(
    writer: &mut OwnedWriteHalf,
    state: &mut ClientState,
) -> io::Result<()> {
    println!();
    println!(
        "[auto-poll] timer fired (interval={}s)",
        state.effective_poll_secs()
    );
    if let (Some(sid), Some(serial)) = (state.session_id, state.serial) {
        println!(
            "[auto-poll] send Serial Query with session_id={}, serial={}",
            sid, serial
        );
        send_serial_query(writer, state.version, sid, serial).await?;
    } else {
        println!("[auto-poll] local state incomplete, send Reset Query");
        send_reset_query(writer, state.version).await?;
    }
    println!();
    Ok(())
}
/// Parses and executes one line typed on the console.
///
/// Returns `Ok(true)` when the user asked to quit, `Ok(false)` otherwise.
/// Invalid arguments are reported to stdout and treated as no-ops rather
/// than errors, so a typo never kills the session.
async fn handle_console_command(
    line: &str,
    writer: &mut OwnedWriteHalf,
    state: &mut ClientState,
) -> io::Result<bool> {
    let line = line.trim();
    if line.is_empty() {
        return Ok(false);
    }
    let parts: Vec<&str> = line.split_whitespace().collect();
    match parts.as_slice() {
        ["help"] => {
            print_help();
        }
        ["state"] => {
            print_state(state);
        }
        ["reset"] => {
            println!("manual command: send Reset Query");
            send_reset_query(writer, state.version).await?;
            state.schedule_next_poll();
        }
        // `serial` with no args re-uses the last known session/serial.
        ["serial"] => {
            match (state.session_id, state.serial) {
                (Some(session_id), Some(serial)) => {
                    println!(
                        "manual command: send Serial Query with current state: session_id={}, serial={}",
                        session_id, serial
                    );
                    send_serial_query(writer, state.version, session_id, serial).await?;
                    state.schedule_next_poll();
                }
                _ => {
                    println!(
                        "manual command failed: current session_id/serial not available, use `reset` or `serial <session_id> <serial>`"
                    );
                }
            }
        }
        // Explicit form overrides local state before sending.
        ["serial", session_id, serial] => {
            let session_id = match session_id.parse::<u16>() {
                Ok(v) => v,
                Err(err) => {
                    println!("invalid session_id: {}", err);
                    return Ok(false);
                }
            };
            let serial = match serial.parse::<u32>() {
                Ok(v) => v,
                Err(err) => {
                    println!("invalid serial: {}", err);
                    return Ok(false);
                }
            };
            println!(
                "manual command: send Serial Query with explicit args: session_id={}, serial={}",
                session_id, serial
            );
            state.session_id = Some(session_id);
            state.serial = Some(serial);
            send_serial_query(writer, state.version, session_id, serial).await?;
            state.schedule_next_poll();
        }
        ["timeout"] => {
            println!("current read timeout: {}s", state.read_timeout_secs);
        }
        ["timeout", secs] => {
            let secs = match secs.parse::<u64>() {
                Ok(v) if v > 0 => v,
                Ok(_) => {
                    println!("timeout must be > 0");
                    return Ok(false);
                }
                Err(err) => {
                    println!("invalid timeout seconds: {}", err);
                    return Ok(false);
                }
            };
            state.read_timeout_secs = secs;
            println!("updated read timeout to {}s", state.read_timeout_secs);
        }
        ["poll"] => {
            println!(
                "current effective poll interval: {}s",
                state.effective_poll_secs()
            );
            println!("stored refresh hint : {:?}", state.refresh);
            println!("default poll interval : {}s", state.default_poll_secs);
        }
        // Overriding the poll interval is stored in the `refresh` slot so it
        // follows the same precedence as a server-provided refresh hint.
        ["poll", secs] => {
            let secs = match secs.parse::<u64>() {
                Ok(v) if v > 0 => v,
                Ok(_) => {
                    println!("poll interval must be > 0");
                    return Ok(false);
                }
                Err(err) => {
                    println!("invalid poll seconds: {}", err);
                    return Ok(false);
                }
            };
            state.refresh = Some(secs as u32);
            state.schedule_next_poll();
            println!("updated poll interval to {}s", secs);
        }
        ["quit"] | ["exit"] => {
            return Ok(true);
        }
        _ => {
            println!("unknown command: {}", line);
            print_help();
        }
    }
    Ok(false)
}
/// Prints the interactive command reference to stdout.
fn print_help() {
    println!("available commands:");
    println!(" help show this help");
    println!(" state print current client state");
    println!(" reset send Reset Query");
    println!(" serial send Serial Query with current session_id/serial");
    println!(" serial <sid> <serial> send Serial Query with explicit values");
    println!(" timeout show current read timeout");
    println!(" timeout <secs> update read timeout seconds");
    println!(" poll show current poll interval");
    println!(" poll <secs> override poll interval seconds");
    println!(" quit exit client");
    println!();
}
/// Dumps every field of the client state for the `state` console command.
fn print_state(state: &ClientState) {
    println!("client state:");
    println!(" version : {}", state.version);
    println!(" session_id : {:?}", state.session_id);
    println!(" serial : {:?}", state.serial);
    println!(" current_session_id : {:?}", state.current_session_id);
    println!(" refresh : {:?}", state.refresh);
    println!(" retry : {:?}", state.retry);
    println!(" expire : {:?}", state.expire);
    println!(" read_timeout_secs : {}", state.read_timeout_secs);
    println!(" default_poll_secs : {}", state.default_poll_secs);
    println!(" effective_poll_secs: {}", state.effective_poll_secs());
    println!();
}
/// Mutable per-connection state of the debug client.
#[derive(Debug)]
struct ClientState {
    // RTR protocol version used for every outgoing PDU.
    version: u8,
    // Last session id / serial committed by an EndOfData (or set manually);
    // used to build incremental Serial Queries.
    session_id: Option<u16>,
    serial: Option<u32>,
    // Session id observed in the response currently being streamed.
    current_session_id: Option<u16>,
    // Timing hints from the last v1+ EndOfData; `refresh` (when set)
    // overrides `default_poll_secs` for auto polling.
    refresh: Option<u32>,
    retry: Option<u32>,
    expire: Option<u32>,
    read_timeout_secs: u64,
    default_poll_secs: u64,
    // Absolute deadline of the next auto-poll tick.
    next_poll_deadline: Instant,
}
impl ClientState {
    /// Fresh state with no session knowledge; the first poll deadline is
    /// set `default_poll_secs` from now.
    fn new(version: u8, read_timeout_secs: u64, default_poll_secs: u64) -> Self {
        Self {
            version,
            session_id: None,
            serial: None,
            current_session_id: None,
            refresh: None,
            retry: None,
            expire: None,
            read_timeout_secs,
            default_poll_secs,
            next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs),
        }
    }
    /// Poll interval actually in effect: the server refresh hint when known,
    /// otherwise the configured default.
    fn effective_poll_secs(&self) -> u64 {
        self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs)
    }
    /// Re-arms the auto-poll timer relative to now.
    fn schedule_next_poll(&mut self) {
        self.next_poll_deadline =
            Instant::now() + Duration::from_secs(self.effective_poll_secs());
    }
}
/// Immutable startup configuration parsed from the command line.
#[derive(Debug)]
struct Config {
    // Server address, `host:port`.
    addr: String,
    // RTR protocol version (0..=2).
    version: u8,
    // Initial query to send after connecting.
    mode: QueryMode,
    read_timeout_secs: u64,
    default_poll_secs: u64,
}
impl Config {
    /// Parses `[addr] [version] [mode ...]` from the process arguments.
    ///
    /// Defaults: addr `127.0.0.1:3323`, version 1, mode `reset`.
    /// `serial` mode additionally requires `<session_id> <serial>`.
    ///
    /// # Errors
    /// Returns `io::ErrorKind::InvalidInput` for an unparsable version,
    /// a version outside 0..=2, an unknown mode, or missing/invalid
    /// serial-mode arguments.
    fn from_args() -> io::Result<Self> {
        // All failures here are user-input problems; a single helper keeps
        // the error construction from being repeated at every call site.
        fn bad(msg: String) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidInput, msg)
        }
        let mut args = env::args().skip(1);
        let addr = args
            .next()
            .unwrap_or_else(|| "127.0.0.1:3323".to_string());
        let version = match args.next() {
            Some(s) => s
                .parse::<u8>()
                .map_err(|e| bad(format!("invalid version '{}': {}", s, e)))?,
            None => 1,
        };
        if version > 2 {
            return Err(bad(format!(
                "unsupported RTR version {}, expected 0..=2",
                version
            )));
        }
        let mode = match args.next().as_deref() {
            None | Some("reset") => QueryMode::Reset,
            Some("serial") => {
                let session_id = args
                    .next()
                    .ok_or_else(|| {
                        bad("serial mode requires session_id and serial".to_string())
                    })?
                    .parse::<u16>()
                    .map_err(|e| bad(format!("invalid session_id: {}", e)))?;
                let serial = args
                    .next()
                    .ok_or_else(|| bad("serial mode requires serial".to_string()))?
                    .parse::<u32>()
                    .map_err(|e| bad(format!("invalid serial: {}", e)))?;
                QueryMode::Serial { session_id, serial }
            }
            Some(other) => {
                return Err(bad(format!(
                    "invalid mode '{}', expected 'reset' or 'serial'",
                    other
                )));
            }
        };
        Ok(Self {
            addr,
            version,
            mode,
            read_timeout_secs: DEFAULT_READ_TIMEOUT_SECS,
            default_poll_secs: DEFAULT_POLL_INTERVAL_SECS,
        })
    }
}

View File

@ -0,0 +1,303 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use crate::protocol::{
flag_meaning, hex_bytes, PduHeader, PduType, ASPA_FIXED_BODY_LEN,
END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN,
IPV6_PREFIX_BODY_LEN, ROUTER_KEY_FIXED_BODY_LEN,
};
/// Prints one received PDU: a separator, the common header fields, then a
/// type-specific body dump dispatched to the `print_*` helpers below.
pub fn print_pdu(header: &PduHeader, body: &[u8]) {
    println!("--------------------------------------------------");
    println!("PDU: {}", header.pdu_type());
    println!("version : {}", header.version);
    println!("length : {}", header.length);
    match header.pdu_type() {
        PduType::CacheResponse => {
            println!("session_id : {}", header.session_id());
        }
        PduType::CacheReset => {
            println!("cache reset");
        }
        PduType::Ipv4Prefix => {
            print_ipv4_prefix(header, body);
        }
        PduType::Ipv6Prefix => {
            print_ipv6_prefix(header, body);
        }
        PduType::EndOfData => {
            print_end_of_data(header, body);
        }
        PduType::ErrorReport => {
            print_error_report(header, body);
        }
        PduType::SerialNotify => {
            print_serial_notify(header, body);
        }
        PduType::SerialQuery => {
            print_serial_query(header, body);
        }
        PduType::Aspa => {
            print_aspa(header, body);
        }
        PduType::ResetQuery => {
            println!("reset query");
        }
        // Unknown types: show the raw header field and a hex dump so
        // unexpected traffic is still inspectable.
        PduType::Unknown(_) => {
            println!("field1 : {}", header.field1);
            println!("body : {}", hex_bytes(body));
        }
    }
}
/// Pretty-prints an IPv4 Prefix PDU body; on a length mismatch, dumps the
/// raw bytes instead of guessing at field boundaries.
fn print_ipv4_prefix(header: &PduHeader, body: &[u8]) {
    if body.len() != IPV4_PREFIX_BODY_LEN {
        println!("invalid IPv4 Prefix body length: {}", body.len());
        println!("raw body: {}", hex_bytes(body));
        return;
    }
    // Layout: flags, prefix_len, max_len, zero, 4-byte address, 4-byte ASN.
    let prefix = Ipv4Addr::new(body[4], body[5], body[6], body[7]);
    let asn = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
    println!("session_id : {}", header.session_id());
    println!("flags : 0x{:02x} ({})", body[0], flag_meaning(body[0]));
    println!("prefix_len : {}", body[1]);
    println!("max_len : {}", body[2]);
    println!("zero : {}", body[3]);
    println!("prefix : {}", prefix);
    println!("asn : {}", asn);
}
/// Pretty-prints an IPv6 Prefix PDU body; on a length mismatch, dumps the
/// raw bytes instead of guessing at field boundaries.
fn print_ipv6_prefix(header: &PduHeader, body: &[u8]) {
    if body.len() != IPV6_PREFIX_BODY_LEN {
        println!("invalid IPv6 Prefix body length: {}", body.len());
        println!("raw body: {}", hex_bytes(body));
        return;
    }
    // Layout: flags, prefix_len, max_len, zero, 16-byte address, 4-byte ASN.
    let mut octets = [0u8; 16];
    octets.copy_from_slice(&body[4..20]);
    let prefix = Ipv6Addr::from(octets);
    let asn = u32::from_be_bytes([body[20], body[21], body[22], body[23]]);
    println!("session_id : {}", header.session_id());
    println!("flags : 0x{:02x} ({})", body[0], flag_meaning(body[0]));
    println!("prefix_len : {}", body[1]);
    println!("max_len : {}", body[2]);
    println!("zero : {}", body[3]);
    println!("prefix : {}", prefix);
    println!("asn : {}", asn);
}
/// Pretty-prints an End of Data PDU, distinguishing the v0 body (serial
/// only) from the v1/v2 body (serial + refresh/retry/expire) by length.
fn print_end_of_data(header: &PduHeader, body: &[u8]) {
    println!("session_id : {}", header.session_id());
    match body.len() {
        END_OF_DATA_V0_BODY_LEN => {
            let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
            println!("serial : {}", serial);
            println!("variant : v0");
        }
        END_OF_DATA_V1_BODY_LEN => {
            let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
            let refresh = u32::from_be_bytes([body[4], body[5], body[6], body[7]]);
            let retry = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
            let expire = u32::from_be_bytes([body[12], body[13], body[14], body[15]]);
            println!("serial : {}", serial);
            println!("refresh : {}", refresh);
            println!("retry : {}", retry);
            println!("expire : {}", expire);
            println!("variant : v1/v2");
        }
        _ => {
            println!("invalid EndOfData body length: {}", body.len());
            println!("raw body : {}", hex_bytes(body));
        }
    }
}
/// Pretty-prints an Error Report PDU body:
/// `[encap_len:u32][encapsulated PDU][text_len:u32][text]`.
/// Each length field is validated before slicing so a truncated body is
/// reported rather than panicking.
fn print_error_report(header: &PduHeader, body: &[u8]) {
    // For Error Report the 16-bit header field carries the error code.
    println!("error_code : {}", header.error_code());
    if body.len() < 8 {
        println!("invalid ErrorReport body length: {}", body.len());
        println!("raw body : {}", hex_bytes(body));
        return;
    }
    let encapsulated_len =
        u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize;
    // Need room for the encapsulated PDU plus the trailing text length word.
    if body.len() < 4 + encapsulated_len + 4 {
        println!("invalid ErrorReport: truncated encapsulated PDU");
        println!("raw body : {}", hex_bytes(body));
        return;
    }
    let encapsulated = &body[4..4 + encapsulated_len];
    let text_len_offset = 4 + encapsulated_len;
    let text_len = u32::from_be_bytes([
        body[text_len_offset],
        body[text_len_offset + 1],
        body[text_len_offset + 2],
        body[text_len_offset + 3],
    ]) as usize;
    if body.len() < text_len_offset + 4 + text_len {
        println!("invalid ErrorReport: truncated text");
        println!("raw body : {}", hex_bytes(body));
        return;
    }
    let text_bytes = &body[text_len_offset + 4..text_len_offset + 4 + text_len];
    // Lossy decode: diagnostic text may not be valid UTF-8.
    let text = String::from_utf8_lossy(text_bytes);
    println!("encap_len : {}", encapsulated_len);
    println!("encap_pdu : {}", hex_bytes(encapsulated));
    println!("text_len : {}", text_len);
    println!("text : {}", text);
}
fn print_serial_notify(header: &PduHeader, body: &[u8]) {
if body.len() != 4 {
println!("invalid Serial Notify body length: {}", body.len());
println!("raw body : {}", hex_bytes(body));
return;
}
let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
println!("session_id : {}", header.session_id());
println!("serial : {}", serial);
}
/// Pretty-prints a Serial Query PDU (session id + 4-byte serial).
fn print_serial_query(header: &PduHeader, body: &[u8]) {
    if body.len() == 4 {
        let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
        println!("session_id : {}", header.session_id());
        println!("serial : {}", serial);
    } else {
        println!("invalid Serial Query body length: {}", body.len());
        println!("raw body : {}", hex_bytes(body));
    }
}
/// Pretty-prints a Router Key PDU body.
///
/// NOTE(review): currently dead code — `PduType` has no RouterKey (type 9)
/// variant, so such PDUs decode as `Unknown` and never reach this printer.
/// Field layout assumed here: flags, zero, 20-byte SKI, 4-byte ASN, then
/// the subjectPublicKeyInfo — TODO confirm against RFC 8210 §5.10.
#[allow(dead_code)]
fn print_router_key(header: &PduHeader, body: &[u8]) {
    println!("session_id : {}", header.session_id());
    if body.len() < ROUTER_KEY_FIXED_BODY_LEN {
        println!("invalid Router Key body length: {}", body.len());
        println!("raw body : {}", hex_bytes(body));
        return;
    }
    let flags = body[0];
    let zero = body[1];
    // 20-byte Subject Key Identifier.
    let ski = &body[2..22];
    let asn = u32::from_be_bytes([body[22], body[23], body[24], body[25]]);
    // Everything after the fixed part is the variable-length SPKI.
    let spki = &body[26..];
    println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags));
    println!("zero : {}", zero);
    println!("ski : {}", hex_bytes(ski));
    println!("asn : {}", asn);
    println!("spki_len : {}", spki.len());
    println!("spki : {}", hex_bytes(spki));
}
/// Pretty-prints an ASPA PDU body: flags, three reserved bytes, the
/// customer ASN, then a list of 4-byte provider ASNs.
fn print_aspa(header: &PduHeader, body: &[u8]) {
    println!("session_id : {}", header.session_id());
    if body.len() < ASPA_FIXED_BODY_LEN {
        println!("invalid ASPA body length: {}", body.len());
        println!("raw body : {}", hex_bytes(body));
        return;
    }
    let flags = body[0];
    let customer_asn = u32::from_be_bytes([body[4], body[5], body[6], body[7]]);
    println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags));
    println!("reserved : [{}, {}, {}]", body[1], body[2], body[3]);
    println!("customer_as : {}", customer_asn);
    let providers_raw = &body[8..];
    // Provider list must be a whole number of 4-byte ASNs.
    if providers_raw.len() % 4 != 0 {
        println!("invalid ASPA providers length: {}", providers_raw.len());
        println!("providers : {}", hex_bytes(providers_raw));
        return;
    }
    let providers: Vec<u32> = providers_raw
        .chunks_exact(4)
        .map(|c| u32::from_be_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    println!("providers : {:?}", providers);
}
/// Extracts the 4-byte big-endian serial from a Serial Notify body.
/// Returns `None` for any other body length.
pub fn parse_serial_notify_serial(body: &[u8]) -> Option<u32> {
    match body {
        [a, b, c, d] => Some(u32::from_be_bytes([*a, *b, *c, *d])),
        _ => None,
    }
}
/// Extracts the leading 4-byte big-endian serial from an End of Data body.
/// Accepts the v0 (4-byte) and v1/v2 (16-byte) layouts; `None` otherwise.
pub fn parse_end_of_data_serial(body: &[u8]) -> Option<u32> {
    if body.len() == 4 || body.len() == 16 {
        Some(u32::from_be_bytes([body[0], body[1], body[2], body[3]]))
    } else {
        None
    }
}
/// Decoded End of Data payload. The refresh/retry/expire timers are only
/// present in the v1/v2 (16-byte) body, hence the `Option`s.
#[derive(Debug, Clone, Copy)]
pub struct EndOfDataInfo {
    pub serial: u32,
    pub refresh: Option<u32>,
    pub retry: Option<u32>,
    pub expire: Option<u32>,
}

/// Parses an End of Data body into [`EndOfDataInfo`].
///
/// A 4-byte body (v0) yields only the serial; a 16-byte body (v1/v2) also
/// yields the three timer values. Any other length returns `None`.
pub fn parse_end_of_data_info(body: &[u8]) -> Option<EndOfDataInfo> {
    // Reads the big-endian u32 starting at byte offset `i`.
    let word = |i: usize| u32::from_be_bytes([body[i], body[i + 1], body[i + 2], body[i + 3]]);
    match body.len() {
        4 => Some(EndOfDataInfo {
            serial: word(0),
            refresh: None,
            retry: None,
            expire: None,
        }),
        16 => Some(EndOfDataInfo {
            serial: word(0),
            refresh: Some(word(4)),
            retry: Some(word(8)),
            expire: Some(word(12)),
        }),
        _ => None,
    }
}

View File

@ -0,0 +1,155 @@
use std::fmt;
// Size of the fixed RTR PDU header in bytes.
pub const HEADER_LEN: usize = 8;
// Total size of a Serial Query PDU (header + 4-byte serial).
pub const SERIAL_QUERY_LEN: usize = 12;
// Upper bound accepted for a peer-declared PDU length; guards against huge
// allocations caused by a malformed or hostile length field.
pub const MAX_PDU_LEN: u32 = 1024 * 1024; // 1 MiB
// Expected body (header excluded) sizes for the fixed-length PDU types.
pub const IPV4_PREFIX_BODY_LEN: usize = 12;
pub const IPV6_PREFIX_BODY_LEN: usize = 24;
// End of Data body: v0 carries only the serial; v1/v2 add three timers.
pub const END_OF_DATA_V0_BODY_LEN: usize = 4;
pub const END_OF_DATA_V1_BODY_LEN: usize = 16;
// Fixed prefix of the variable-length Router Key / ASPA bodies.
pub const ROUTER_KEY_FIXED_BODY_LEN: usize = 26;
pub const ASPA_FIXED_BODY_LEN: usize = 8;
/// The initial query the client sends after connecting: a full Reset Query,
/// or an incremental Serial Query resuming a known session/serial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryMode {
    Reset,
    Serial { session_id: u16, serial: u32 },
}
/// RTR PDU types handled by the debug client (RFC 6810 / RFC 8210 codes).
///
/// Unrecognised type bytes are preserved in `Unknown` so they can still be
/// reported verbatim instead of being dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PduType {
    SerialNotify,
    SerialQuery,
    ResetQuery,
    CacheResponse,
    Ipv4Prefix,
    Ipv6Prefix,
    EndOfData,
    CacheReset,
    ErrorReport,
    Aspa,
    Unknown(u8),
}

impl PduType {
    /// Wire code of this PDU type (second byte of the header).
    pub fn code(self) -> u8 {
        match self {
            PduType::SerialNotify => 0,
            PduType::SerialQuery => 1,
            PduType::ResetQuery => 2,
            PduType::CacheResponse => 3,
            PduType::Ipv4Prefix => 4,
            PduType::Ipv6Prefix => 6,
            PduType::EndOfData => 7,
            PduType::CacheReset => 8,
            PduType::ErrorReport => 10,
            PduType::Aspa => 11,
            PduType::Unknown(v) => v,
        }
    }

    /// Human-readable name used in the pretty-printed output.
    pub fn name(self) -> &'static str {
        match self {
            PduType::SerialNotify => "Serial Notify",
            PduType::SerialQuery => "Serial Query",
            PduType::ResetQuery => "Reset Query",
            PduType::CacheResponse => "Cache Response",
            PduType::Ipv4Prefix => "IPv4 Prefix",
            PduType::Ipv6Prefix => "IPv6 Prefix",
            PduType::EndOfData => "End of Data",
            PduType::CacheReset => "Cache Reset",
            PduType::ErrorReport => "Error Report",
            PduType::Aspa => "ASPA",
            PduType::Unknown(_) => "Unknown",
        }
    }
}

impl From<u8> for PduType {
    /// Maps a raw type byte to the matching variant, keeping any
    /// unrecognised value inside `Unknown`.
    fn from(value: u8) -> Self {
        match value {
            0 => PduType::SerialNotify,
            1 => PduType::SerialQuery,
            2 => PduType::ResetQuery,
            3 => PduType::CacheResponse,
            4 => PduType::Ipv4Prefix,
            6 => PduType::Ipv6Prefix,
            7 => PduType::EndOfData,
            8 => PduType::CacheReset,
            10 => PduType::ErrorReport,
            11 => PduType::Aspa,
            x => PduType::Unknown(x),
        }
    }
}

impl fmt::Display for PduType {
    /// Renders as `"<name> (<code>)"`. For `Unknown(v)`, `code()` already
    /// returns `v`, so one arm covers every variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} ({})", self.name(), self.code())
    }
}
#[derive(Debug, Clone)]
pub struct PduHeader {
pub version: u8,
pub pdu_type_raw: u8,
pub field1: u16,
pub length: u32,
}
impl PduHeader {
pub fn from_bytes(buf: [u8; HEADER_LEN]) -> Self {
Self {
version: buf[0],
pdu_type_raw: buf[1],
field1: u16::from_be_bytes([buf[2], buf[3]]),
length: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]),
}
}
pub fn pdu_type(&self) -> PduType {
self.pdu_type_raw.into()
}
pub fn session_id(&self) -> u16 {
self.field1
}
pub fn error_code(&self) -> u16 {
self.field1
}
}
/// A PDU as read from the wire: decoded header plus the undecoded body
/// bytes (everything after the 8-byte header).
#[derive(Debug, Clone)]
pub struct RawPdu {
    pub header: PduHeader,
    pub body: Vec<u8>,
}
/// Human-readable meaning of the low flag bit of a prefix/ASPA PDU:
/// bit 0 set means announcement, clear means withdrawal. Other bits are
/// ignored here.
pub fn flag_meaning(flags: u8) -> &'static str {
    match flags & 0x01 {
        0x01 => "announcement",
        _ => "withdrawal",
    }
}
/// Renders a byte slice as space-separated lowercase hex pairs,
/// e.g. `[0xde, 0xad]` -> `"de ad"`. An empty slice yields `"<empty>"`.
pub fn hex_bytes(data: &[u8]) -> String {
    if data.is_empty() {
        return "<empty>".to_string();
    }
    data.iter()
        .map(|b| format!("{:02x}", b))
        .collect::<Vec<_>>()
        .join(" ")
}

View File

@ -0,0 +1,81 @@
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::protocol::{
PduHeader, PduType, RawPdu, HEADER_LEN, MAX_PDU_LEN, SERIAL_QUERY_LEN,
};
/// Encodes and sends a Reset Query PDU (type 2): an 8-byte header with a
/// zero session-id field and a total length of 8, then flushes the stream.
pub async fn send_reset_query<S>(stream: &mut S, version: u8) -> io::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let mut pdu = Vec::with_capacity(HEADER_LEN);
    pdu.push(version);
    pdu.push(PduType::ResetQuery.code());
    pdu.extend_from_slice(&0u16.to_be_bytes());
    pdu.extend_from_slice(&(HEADER_LEN as u32).to_be_bytes());
    stream.write_all(&pdu).await?;
    stream.flush().await
}
/// Encodes and sends a Serial Query PDU (type 1): the 8-byte header
/// carrying the session id, followed by the 4-byte serial, then flushes.
pub async fn send_serial_query<S>(
    stream: &mut S,
    version: u8,
    session_id: u16,
    serial: u32,
) -> io::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let mut pdu = Vec::with_capacity(SERIAL_QUERY_LEN);
    pdu.push(version);
    pdu.push(PduType::SerialQuery.code());
    pdu.extend_from_slice(&session_id.to_be_bytes());
    pdu.extend_from_slice(&(SERIAL_QUERY_LEN as u32).to_be_bytes());
    pdu.extend_from_slice(&serial.to_be_bytes());
    stream.write_all(&pdu).await?;
    stream.flush().await
}
/// Reads exactly `HEADER_LEN` bytes from the stream and decodes them into
/// a [`PduHeader`].
pub async fn read_header<S>(stream: &mut S) -> io::Result<PduHeader>
where
    S: AsyncRead + Unpin,
{
    let mut raw = [0u8; HEADER_LEN];
    stream.read_exact(&mut raw).await?;
    Ok(PduHeader::from_bytes(raw))
}
/// Reads one complete PDU: header first, then a body sized by the header's
/// length field.
///
/// The declared length is validated before allocating: it must cover at
/// least the header itself and must not exceed `MAX_PDU_LEN`, so a hostile
/// or corrupted length cannot trigger a huge allocation.
pub async fn read_pdu<S>(stream: &mut S) -> io::Result<RawPdu>
where
    S: AsyncRead + Unpin,
{
    let header = read_header(stream).await?;
    let declared = header.length;
    if declared < HEADER_LEN as u32 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("invalid PDU length {} < {}", declared, HEADER_LEN),
        ));
    }
    if declared > MAX_PDU_LEN {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("PDU length {} exceeds max allowed {}", declared, MAX_PDU_LEN),
        ));
    }
    let mut body = vec![0u8; declared as usize - HEADER_LEN];
    stream.read_exact(&mut body).await?;
    Ok(RawPdu { header, body })
}

View File

@ -1,3 +1,4 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ASIdentifiers {
@ -54,7 +55,7 @@ impl ASRange {
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default, Serialize, Deserialize)]
pub struct Asn(u32);
impl Asn {

View File

@ -1,71 +1,157 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct IPAddrBlocks {
pub ips: Vec<IPAddressFamily>
pub ips: Vec<IPAddressFamily>,
}
// IP Address Family
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct IPAddressFamily {
pub address_family: Afi,
pub ip_address_choice: IPAddressChoice,
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)]
pub enum Afi {
Ipv4,
Ipv6,
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum IPAddressChoice {
Inherit,
AddressOrRange(Vec<IPAddressOrRange>),
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum IPAddressOrRange {
AddressPrefix(IPAddressPrefix),
AddressRange(IPAddressRange),
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct IPAddressPrefix {
pub address: IPAddress,
pub prefix_length: u8,
}
#[derive(Debug, Clone, PartialEq, Eq)]
impl IPAddressPrefix {
pub fn new(address: IPAddress, prefix_length: u8) -> Self {
Self {
address,
prefix_length,
}
}
pub fn is_ipv4(&self) -> bool {
self.address.is_ipv4()
}
pub fn is_ipv6(&self) -> bool {
self.address.is_ipv6()
}
pub fn afi(&self) -> Afi {
self.address.afi()
}
pub fn address(&self) -> IPAddress {
self.address
}
pub fn prefix_length(&self) -> u8 {
self.prefix_length
}
}
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct IPAddressRange {
pub min: IPAddress,
pub max: IPAddress,
}
use std::net::{Ipv4Addr, Ipv6Addr};
impl IPAddressRange {
pub fn new(min: IPAddress, max: IPAddress) -> Self {
Self { min, max }
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct IPAddress(u128);
pub fn min(&self) -> IPAddress {
self.min
}
pub fn max(&self) -> IPAddress {
self.max
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize, Deserialize)]
pub enum IPAddress {
V4(Ipv4Addr),
V6(Ipv6Addr),
}
impl IPAddress {
pub fn from_ipv4(addr: Ipv4Addr) -> Self {
Self::V4(addr)
}
pub fn from_ipv6(addr: Ipv6Addr) -> Self {
Self::V6(addr)
}
pub fn to_ipv4(self) -> Option<Ipv4Addr> {
if self.0 <= u32::MAX as u128 {
Some(Ipv4Addr::from(self.0 as u32))
} else {
None
match self {
Self::V4(addr) => Some(addr),
Self::V6(_) => None,
}
}
pub fn to_ipv6(self) -> Ipv6Addr {
Ipv6Addr::from(self.0)
pub fn to_ipv6(self) -> Option<Ipv6Addr> {
match self {
Self::V4(_) => None,
Self::V6(addr) => Some(addr),
}
}
pub fn is_ipv4(self) -> bool {
self.0 <= u32::MAX as u128
matches!(self, Self::V4(_))
}
pub fn is_ipv6(self) -> bool {
matches!(self, Self::V6(_))
}
pub fn afi(self) -> Afi {
match self {
Self::V4(_) => Afi::Ipv4,
Self::V6(_) => Afi::Ipv6,
}
}
/// Returns the numeric address value.
///
/// For IPv4, this is the 32-bit address widened to u128.
/// For IPv6, this is the full 128-bit address.
pub fn as_u128(self) -> u128 {
self.0
match self {
Self::V4(addr) => u32::from(addr) as u128,
Self::V6(addr) => u128::from(addr),
}
}
pub fn as_v4_u32(self) -> Option<u32> {
match self {
Self::V4(addr) => Some(u32::from(addr)),
Self::V6(_) => None,
}
}
pub fn as_v6_u128(self) -> Option<u128> {
match self {
Self::V4(_) => None,
Self::V6(addr) => Some(u128::from(addr)),
}
}
}

View File

@ -1,3 +1,3 @@
pub(crate) mod ip_resources;
pub(crate) mod as_resources;
pub mod ip_resources;
pub mod as_resources;
pub mod resource;

View File

@ -1,3 +1,3 @@
pub mod data_model;
mod slurm;
mod rtr;
pub mod rtr;

218
src/main.rs Normal file
View File

@ -0,0 +1,218 @@
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use anyhow::{anyhow, Result};
use tokio::task::JoinHandle;
use tracing::{info, warn};
use rpki::rtr::cache::{RtrCache, SharedRtrCache};
use rpki::rtr::loader::load_vrps_from_file;
use rpki::rtr::payload::Timing;
use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService};
use rpki::rtr::store_db::RtrStore;
/// Static configuration of the RTR server process.
#[derive(Debug, Clone)]
struct AppConfig {
    // When true, a TLS listener is started on `tls_addr` in addition to
    // the plain-TCP listener on `tcp_addr`.
    enable_tls: bool,
    tcp_addr: SocketAddr,
    tls_addr: SocketAddr,
    // RocksDB directory backing the RTR store.
    db_path: String,
    // VRP input file loaded into the cache at startup and on refresh.
    vrp_file: String,
    tls_cert_path: String,
    tls_key_path: String,
    // Maximum number of deltas retained for incremental serves.
    max_delta: u8,
    // How often the refresh task re-reads the VRP file.
    refresh_interval: Duration,
    service_config: RtrServiceConfig,
}
impl Default for AppConfig {
    /// Development defaults: plain TCP on 3323, TLS disabled, local store
    /// and data paths relative to the working directory.
    fn default() -> Self {
        Self {
            enable_tls: false,
            tcp_addr: "0.0.0.0:3323".parse().expect("invalid default tcp_addr"),
            tls_addr: "0.0.0.0:3324".parse().expect("invalid default tls_addr"),
            db_path: "./rtr-db".to_string(),
            // Relative path instead of a developer-machine absolute path so
            // the default works on any checkout; override for deployments.
            vrp_file: "./data/vrps.txt".to_string(),
            tls_cert_path: "./certs/server.crt".to_string(),
            tls_key_path: "./certs/server.key".to_string(),
            max_delta: 100,
            refresh_interval: Duration::from_secs(300),
            service_config: RtrServiceConfig {
                max_connections: 512,
                notify_queue_size: 1024,
            },
        }
    }
}
/// Entry point: wires logging, store, cache, listeners and the refresh task,
/// then blocks until Ctrl-C and shuts everything down in order.
#[tokio::main]
async fn main() -> Result<()> {
    // Set up logging first so startup diagnostics are visible.
    init_tracing();
    let config = AppConfig::default();
    log_startup_config(&config);
    // The persistent store backs the in-memory cache across restarts.
    let store = open_store(&config)?;
    let shared_cache = init_shared_cache(&config, &store)?;
    let service = RtrService::with_config(shared_cache.clone(), config.service_config.clone());
    let notifier = service.notifier();
    let running = start_servers(&config, &service);
    let refresh_task = spawn_refresh_task(&config, shared_cache.clone(), store.clone(), notifier);
    // Block until the shutdown signal, then stop listeners before the refresher.
    wait_for_shutdown().await?;
    running.shutdown();
    running.wait().await;
    refresh_task.abort();
    // Reap the aborted task; a JoinError from abort() is expected here.
    let _ = refresh_task.await;
    info!("RTR service stopped");
    Ok(())
}
/// Opens the persistent RTR store at the configured path.
fn open_store(config: &AppConfig) -> Result<RtrStore> {
    let path = config.db_path.as_str();
    info!("opening RTR store: {}", path);
    RtrStore::open(path)
}
/// Builds the shared, lock-protected cache, restoring from the store when
/// possible and otherwise loading VRPs from the configured file.
fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result<SharedRtrCache> {
    let initial_cache = RtrCache::default().init(
        store,
        config.max_delta,
        Timing::default(),
        // Fallback loader, invoked only when the store cannot be restored.
        || load_vrps_from_file(&config.vrp_file),
    )?;
    let shared_cache: SharedRtrCache = Arc::new(RwLock::new(initial_cache));
    {
        // Scoped read lock, released before the Arc is returned.
        let cache = shared_cache
            .read()
            .map_err(|_| anyhow!("cache read lock poisoned during startup"))?;
        info!(
            "cache initialized: session_id={}, serial={}",
            cache.session_id(),
            cache.serial()
        );
    }
    Ok(shared_cache)
}
/// Spawns the configured listeners: TCP only, or TCP plus TLS when enabled.
fn start_servers(config: &AppConfig, service: &RtrService) -> RunningRtrService {
    if !config.enable_tls {
        info!("starting TCP RTR server");
        return service.spawn_tcp_only(config.tcp_addr);
    }
    info!("starting TCP and TLS RTR servers");
    service.spawn_tcp_and_tls_from_pem(
        config.tcp_addr,
        config.tls_addr,
        &config.tls_cert_path,
        &config.tls_key_path,
    )
}
/// Spawns the background task that periodically reloads the VRP file into
/// the cache and broadcasts a serial notify when the serial advanced.
fn spawn_refresh_task(
    config: &AppConfig,
    shared_cache: SharedRtrCache,
    store: RtrStore,
    notifier: RtrNotifier,
) -> JoinHandle<()> {
    let refresh_interval = config.refresh_interval;
    let vrp_file = config.vrp_file.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(refresh_interval);
        loop {
            // tokio intervals complete their first tick immediately, so one
            // refresh runs right after startup.
            interval.tick().await;
            match load_vrps_from_file(&vrp_file) {
                Ok(payloads) => {
                    // Hold the write lock only for the update itself.
                    let updated = {
                        let mut cache = match shared_cache.write() {
                            Ok(guard) => guard,
                            Err(_) => {
                                warn!("cache write lock poisoned during refresh");
                                continue;
                            }
                        };
                        let old_serial = cache.serial();
                        // A serial change is the signal that clients must be
                        // told about new data.
                        match cache.update(payloads, &store) {
                            Ok(()) => cache.serial() != old_serial,
                            Err(err) => {
                                warn!("RTR cache update failed: {:?}", err);
                                false
                            }
                        }
                    };
                    if updated {
                        notifier.notify_cache_updated();
                        info!("RTR cache updated, serial notify broadcast sent");
                    }
                }
                Err(err) => {
                    // Keep the old cache on load errors; retry next tick.
                    warn!("failed to reload VRPs from file {}: {:?}", vrp_file, err);
                }
            }
        }
    })
}
/// Blocks until Ctrl-C is received, then logs the shutdown request.
async fn wait_for_shutdown() -> Result<()> {
    let signal = tokio::signal::ctrl_c();
    signal.await?;
    info!("shutdown signal received");
    Ok(())
}
/// Logs the effective configuration at startup, one line per field; TLS
/// details are only logged when TLS is enabled.
fn log_startup_config(config: &AppConfig) {
    info!("starting RTR service");
    info!("db_path={}", config.db_path);
    info!("tcp_addr={}", config.tcp_addr);
    info!("tls_enabled={}", config.enable_tls);
    if config.enable_tls {
        info!("tls_addr={}", config.tls_addr);
        info!("tls_cert_path={}", config.tls_cert_path);
        info!("tls_key_path={}", config.tls_key_path);
    }
    info!("vrp_file={}", config.vrp_file);
    info!("max_delta={}", config.max_delta);
    info!("refresh_interval_secs={}", config.refresh_interval.as_secs());
    info!("max_connections={}", config.service_config.max_connections);
    info!("notify_queue_size={}", config.service_config.notify_queue_size);
}
/// Installs a fmt tracing subscriber; silently keeps any subscriber that is
/// already installed (try_init error is ignored on purpose).
fn init_tracing() {
    let builder = tracing_subscriber::fmt()
        .with_target(true)
        .with_thread_ids(true)
        .with_level(true);
    let _ = builder.try_init();
}

View File

@ -1,16 +1,22 @@
use std::cmp::Ordering;
use std::collections::{BTreeSet, VecDeque};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey};
use sha2::{Digest, Sha256};
use crate::data_model::resources::ip_resources::IPAddress;
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Timing};
use crate::rtr::store_db::RtrStore;
const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(600);
const DEFAULT_EXPIRE_INTERVAL: Duration = Duration::from_secs(7200);
pub type SharedRtrCache = Arc<RwLock<RtrCache>>;
#[derive(Debug, Clone)]
pub struct DualTime {
instant: Instant,
@ -79,6 +85,11 @@ pub struct Snapshot {
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
created_at: DualTime,
origins_hash: [u8; 32],
router_keys_hash: [u8; 32],
aspas_hash: [u8; 32],
snapshot_hash: [u8; 32],
}
impl Snapshot {
@ -87,12 +98,18 @@ impl Snapshot {
router_keys: BTreeSet<RouterKey>,
aspas: BTreeSet<Aspa>,
) -> Self {
Snapshot {
let mut snapshot = Snapshot {
origins,
router_keys,
aspas,
created_at: DualTime::now(),
}
origins_hash: [0u8; 32],
router_keys_hash: [0u8; 32],
aspas_hash: [0u8; 32],
snapshot_hash: [0u8; 32],
};
snapshot.recompute_hashes();
snapshot
}
pub fn empty() -> Self {
@ -118,38 +135,86 @@ impl Snapshot {
}
}
Snapshot {
origins,
router_keys,
aspas,
created_at: DualTime::now(),
Snapshot::new(origins, router_keys, aspas)
}
/// Recomputes the three per-set hashes and the combined snapshot hash.
/// Must be called whenever any of the underlying sets changes.
pub fn recompute_hashes(&mut self) {
    self.origins_hash = self.compute_origins_hash();
    self.router_keys_hash = self.compute_router_keys_hash();
    self.aspas_hash = self.compute_aspas_hash();
    // Combined hash is derived from the three set hashes, so it must be
    // computed last.
    self.snapshot_hash = self.compute_snapshot_hash();
}

fn compute_origins_hash(&self) -> [u8; 32] {
    Self::hash_ordered_iter(self.origins.iter())
}

fn compute_router_keys_hash(&self) -> [u8; 32] {
    Self::hash_ordered_iter(self.router_keys.iter())
}

fn compute_aspas_hash(&self) -> [u8; 32] {
    Self::hash_ordered_iter(self.aspas.iter())
}

/// Domain-separated SHA-256 over the three set hashes.
fn compute_snapshot_hash(&self) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(b"snapshot:v1");
    hasher.update(self.origins_hash);
    hasher.update(self.router_keys_hash);
    hasher.update(self.aspas_hash);
    hasher.finalize().into()
}

/// Hashes an ordered sequence of serializable items.
///
/// Each item is JSON-encoded and length-prefixed so item boundaries are
/// unambiguous. Iteration order matters; callers pass BTreeSet iterators,
/// which yield a deterministic order.
/// NOTE(review): panics if serde_json serialization fails — acceptable only
/// if payload types can never fail to serialize; confirm.
fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32]
where
    T: Serialize + 'a,
    I: IntoIterator<Item = &'a T>,
{
    let mut hasher = Sha256::new();
    hasher.update(b"set:v1");
    for item in iter {
        let encoded =
            serde_json::to_vec(item).expect("serialize snapshot item for hashing failed");
        // Big-endian length prefix before each encoded item.
        let len = (encoded.len() as u32).to_be_bytes();
        hasher.update(len);
        hasher.update(encoded);
    }
    hasher.finalize().into()
}
/// Computes the payload delta from `self` to `new_snapshot`.
///
/// Returns `(announced, withdrawn)`: items only in `new_snapshot` are
/// announced, items only in `self` are withdrawn. The per-set hash
/// comparisons skip the set difference entirely when a set is unchanged.
pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec<Payload>, Vec<Payload>) {
    let mut announced = Vec::new();
    let mut withdrawn = Vec::new();
    if !self.same_origins(new_snapshot) {
        for origin in new_snapshot.origins.difference(&self.origins) {
            announced.push(Payload::RouteOrigin(origin.clone()));
        }
        for origin in self.origins.difference(&new_snapshot.origins) {
            withdrawn.push(Payload::RouteOrigin(origin.clone()));
        }
    }
    if !self.same_router_keys(new_snapshot) {
        for key in new_snapshot.router_keys.difference(&self.router_keys) {
            announced.push(Payload::RouterKey(key.clone()));
        }
        for key in self.router_keys.difference(&new_snapshot.router_keys) {
            withdrawn.push(Payload::RouterKey(key.clone()));
        }
    }
    if !self.same_aspas(new_snapshot) {
        for aspa in new_snapshot.aspas.difference(&self.aspas) {
            announced.push(Payload::Aspa(aspa.clone()));
        }
        for aspa in self.aspas.difference(&new_snapshot.aspas) {
            withdrawn.push(Payload::Aspa(aspa.clone()));
        }
    }
    (announced, withdrawn)
}
@ -169,6 +234,66 @@ impl Snapshot {
v
}
/// Payloads sorted for RTR full snapshot sending.
/// Snapshot represents current valid state, so all payloads are treated as announcements.
pub fn payloads_for_rtr(&self) -> Vec<Payload> {
    let mut payloads = self.payloads();
    sort_payloads_for_rtr(&mut payloads, true);
    payloads
}

/// SHA-256 over the ordered route-origin set.
pub fn origins_hash(&self) -> [u8; 32] {
    self.origins_hash
}

/// SHA-256 over the ordered router-key set.
pub fn router_keys_hash(&self) -> [u8; 32] {
    self.router_keys_hash
}

/// SHA-256 over the ordered ASPA set.
pub fn aspas_hash(&self) -> [u8; 32] {
    self.aspas_hash
}

/// Combined hash covering all three sets.
pub fn snapshot_hash(&self) -> [u8; 32] {
    self.snapshot_hash
}

pub fn same_origins(&self, other: &Self) -> bool {
    self.origins_hash == other.origins_hash
}

pub fn same_router_keys(&self, other: &Self) -> bool {
    self.router_keys_hash == other.router_keys_hash
}

pub fn same_aspas(&self, other: &Self) -> bool {
    self.aspas_hash == other.aspas_hash
}

/// True when both snapshots hash to the same combined value.
pub fn same_content(&self, other: &Self) -> bool {
    self.snapshot_hash == other.snapshot_hash
}

pub fn origins(&self) -> &BTreeSet<RouteOrigin> {
    &self.origins
}

pub fn router_keys(&self) -> &BTreeSet<RouterKey> {
    &self.router_keys
}

pub fn aspas(&self) -> &BTreeSet<Aspa> {
    &self.aspas
}
}
impl Snapshot {
    /// True when the snapshot carries no payloads at all.
    pub fn is_empty(&self) -> bool {
        self.origins.is_empty()
            && self.router_keys.is_empty()
            && self.aspas.is_empty()
    }
}
#[derive(Debug, Serialize, Deserialize)]
@ -180,7 +305,10 @@ pub struct Delta {
}
impl Delta {
pub fn new(serial: u32, announced: Vec<Payload>, withdrawn: Vec<Payload>) -> Self {
pub fn new(serial: u32, mut announced: Vec<Payload>, mut withdrawn: Vec<Payload>) -> Self {
sort_payloads_for_rtr(&mut announced, true);
sort_payloads_for_rtr(&mut withdrawn, false);
Delta {
serial,
announced,
@ -201,8 +329,8 @@ impl Delta {
&self.withdrawn
}
pub fn created_at(self) -> DualTime {
self.created_at
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
}
@ -219,7 +347,7 @@ pub struct RtrCache {
// Max number of deltas to keep.
max_delta: u8,
// Refresh interval.
refresh_interval: Duration,
timing: Timing,
// Last update begin time.
last_update_begin: DualTime,
// Last update end time.
@ -237,7 +365,7 @@ impl Default for RtrCache {
snapshot: Snapshot::empty(),
deltas: VecDeque::with_capacity(100),
max_delta: 100,
refresh_interval: Duration::from_secs(600),
timing: Timing::default(),
last_update_begin: now.clone(),
last_update_end: now.clone(),
created_at: now,
@ -248,9 +376,10 @@ impl Default for RtrCache {
pub struct RtrCacheBuilder {
session_id: Option<u16>,
max_delta: Option<u8>,
refresh_interval: Option<Duration>,
timing: Option<Timing>,
serial: Option<u32>,
snapshot: Option<Snapshot>,
deltas: Option<VecDeque<Arc<Delta>>>,
created_at: Option<DualTime>,
}
@ -259,9 +388,10 @@ impl RtrCacheBuilder {
Self {
session_id: None,
max_delta: None,
refresh_interval: None,
timing: None,
serial: None,
snapshot: None,
deltas: None,
created_at: None,
}
}
@ -276,8 +406,8 @@ impl RtrCacheBuilder {
self
}
pub fn refresh_interval(mut self, v: Duration) -> Self {
self.refresh_interval = Some(v);
pub fn timing(mut self, v: Timing) -> Self {
self.timing = Some(v);
self
}
@ -291,6 +421,11 @@ impl RtrCacheBuilder {
self
}
pub fn deltas(mut self, v: VecDeque<Arc<Delta>>) -> Self {
self.deltas = Some(v);
self
}
pub fn created_at(mut self, v: DualTime) -> Self {
self.created_at = Some(v);
self
@ -299,8 +434,12 @@ impl RtrCacheBuilder {
pub fn build(self) -> RtrCache {
let now = DualTime::now();
let max_delta = self.max_delta.unwrap_or(100);
let refresh_interval = self.refresh_interval.unwrap_or(Duration::from_secs(600));
let timing = self.timing.unwrap_or_default();
let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty);
let deltas = self
.deltas
.unwrap_or_else(|| VecDeque::with_capacity(max_delta.into()));
let serial = self.serial.unwrap_or(0);
let created_at = self.created_at.unwrap_or_else(|| now.clone());
let session_id = self.session_id.unwrap_or_else(rand::random);
@ -309,9 +448,9 @@ impl RtrCacheBuilder {
session_id,
serial,
snapshot,
deltas: VecDeque::with_capacity(max_delta.into()),
deltas,
max_delta,
refresh_interval,
timing,
last_update_begin: now.clone(),
last_update_end: now,
created_at,
@ -325,42 +464,44 @@ impl RtrCache {
self,
store: &RtrStore,
max_delta: u8,
refresh_interval: Duration,
timing: Timing,
file_loader: impl Fn() -> anyhow::Result<Vec<Payload>>,
) -> anyhow::Result<Self> {
let snapshot = store.get_snapshot()?;
let session_id = store.get_session_id()?;
let serial = store.get_serial()?;
if let (Some(snapshot), Some(session_id), Some(serial)) =
(snapshot, session_id, serial)
{
let mut cache = RtrCacheBuilder::new()
.session_id(session_id)
.max_delta(max_delta)
.refresh_interval(refresh_interval)
.serial(serial)
.snapshot(snapshot)
.build();
if let Some((min_serial, _max_serial)) = store.get_delta_window()? {
let deltas = store.load_deltas_since(min_serial.wrapping_sub(1))?;
for delta in deltas {
cache.push_delta(Arc::new(delta));
}
}
if let Some(cache) = Self::try_restore_from_store(store, max_delta, timing)? {
tracing::info!(
"RTR cache restored from store: session_id={}, serial={}",
self.session_id,
self.serial
);
return Ok(cache);
}
tracing::warn!("RTR cache store unavailable or invalid, fallback to file loader");
let payloads = file_loader()?;
let snapshot = Snapshot::from_payloads(payloads);
if snapshot.is_empty() {
anyhow::bail!("file loader returned an empty snapshot");
}
tracing::info!(
"RTR cache initialized from file loader: session_id={}, serial={}",
self.session_id,
self.serial
);
let serial = 1;
let session_id: u16 = rand::random();
let snapshot_for_store = snapshot.clone();
let snapshot_for_cache = snapshot.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) {
if let Err(e) =
store.save_snapshot_and_meta(&snapshot_for_store, session_id, serial)
{
tracing::error!("persist failed: {:?}", e);
}
});
@ -368,10 +509,65 @@ impl RtrCache {
Ok(RtrCacheBuilder::new()
.session_id(session_id)
.max_delta(max_delta)
.refresh_interval(refresh_interval)
.timing(timing)
.serial(serial)
.snapshot(snapshot_for_cache)
.build())
}
fn try_restore_from_store(
store: &RtrStore,
max_delta: u8,
timing: Timing,
) -> anyhow::Result<Option<Self>> {
let snapshot = store.get_snapshot()?;
let session_id = store.get_session_id()?;
let serial = store.get_serial()?;
let (snapshot, session_id, serial) = match (snapshot, session_id, serial) {
(Some(snapshot), Some(session_id), Some(serial)) => (snapshot, session_id, serial),
_ => {
tracing::warn!("RTR cache store incomplete: snapshot/session_id/serial missing");
return Ok(None);
}
};
if snapshot.is_empty() {
tracing::warn!("RTR cache store snapshot is empty, treat as unusable");
return Ok(None);
}
let mut cache = RtrCacheBuilder::new()
.session_id(session_id)
.max_delta(max_delta)
.timing(timing)
.serial(serial)
.snapshot(snapshot)
.build())
.build();
match store.get_delta_window()? {
Some((min_serial, _max_serial)) => {
let deltas = match store.load_deltas_since(min_serial.wrapping_sub(1)) {
Ok(deltas) => deltas,
Err(err) => {
tracing::warn!(
"RTR cache store delta recovery failed, treat store as unusable: {:?}",
err
);
return Ok(None);
}
};
for delta in deltas {
cache.push_delta(Arc::new(delta));
}
}
None => {
tracing::info!("RTR cache store has no delta window, restore snapshot only");
}
}
Ok(Some(cache))
}
fn next_serial(&mut self) -> u32 {
@ -431,10 +627,19 @@ impl RtrCache {
new_payloads: Vec<Payload>,
store: &RtrStore,
) -> anyhow::Result<()> {
self.last_update_begin = DualTime::now();
let new_snapshot = Snapshot::from_payloads(new_payloads);
if self.snapshot.same_content(&new_snapshot) {
self.last_update_end = DualTime::now();
return Ok(());
}
let (announced, withdrawn) = self.snapshot.diff(&new_snapshot);
if announced.is_empty() && withdrawn.is_empty() {
self.last_update_end = DualTime::now();
return Ok(());
}
@ -462,8 +667,8 @@ impl RtrCache {
self.serial
}
pub fn refresh_interval(&self) -> Duration {
self.refresh_interval
pub fn timing(&self) -> Timing {
self.timing
}
pub fn retry_interval(&self) -> Duration {
@ -477,6 +682,18 @@ impl RtrCache {
pub fn current_snapshot(&self) -> (&Snapshot, u32, u16) {
(&self.snapshot, self.serial, self.session_id)
}
pub fn last_update_begin(&self) -> DualTime {
self.last_update_begin.clone()
}
pub fn last_update_end(&self) -> DualTime {
self.last_update_end.clone()
}
pub fn created_at(&self) -> DualTime {
self.created_at.clone()
}
}
impl RtrCache {
@ -524,6 +741,7 @@ impl RtrCache {
return SerialResult::UpToDate;
}
let _ = newest_serial;
SerialResult::Deltas(result)
}
}
@ -536,3 +754,176 @@ pub enum SerialResult {
/// Delta window cannot cover; reset required.
ResetRequired,
}
//------------ RTR ordering -------------------------------------------------
// RTR PDU type grouping used to order payloads before sending; the
// discriminant values match the on-wire PDU type numbers, so the derived
// Ord gives the PDU-type ordering directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum PayloadPduType {
    Ipv4Prefix = 4,
    Ipv6Prefix = 6,
    RouterKey = 9,
    Aspa = 11,
}
// Numeric sort key extracted from a RouteOrigin: address value, prefix
// length, max length, and origin ASN, split by address family.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum RouteOriginKey {
    V4 {
        addr: u32,
        plen: u8,
        mlen: u8,
        asn: u32,
    },
    V6 {
        addr: u128,
        plen: u8,
        mlen: u8,
        asn: u32,
    },
}
// Sorts payloads into the order they should be written to an RTR stream.
// `announce` selects announcement vs. withdrawal ordering for route origins.
fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) {
    payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce));
}
/// Orders two payloads: first by PDU type, then by the type-specific
/// comparison. `announce` only affects the route-origin ordering.
fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering {
    // then_with only evaluates the inner comparison on equal PDU types.
    payload_pdu_type(a)
        .cmp(&payload_pdu_type(b))
        .then_with(|| match (a, b) {
            (Payload::RouteOrigin(x), Payload::RouteOrigin(y)) => {
                compare_route_origin_for_rtr(x, y, announce)
            }
            (Payload::RouterKey(x), Payload::RouterKey(y)) => {
                compare_router_key_for_rtr(x, y)
            }
            (Payload::Aspa(x), Payload::Aspa(y)) => compare_aspa_for_rtr(x, y),
            _ => Ordering::Equal,
        })
}
/// Maps a payload to its RTR PDU type; route origins split by address family.
fn payload_pdu_type(payload: &Payload) -> PayloadPduType {
    match payload {
        Payload::RouteOrigin(ro) if route_origin_is_ipv4(ro) => PayloadPduType::Ipv4Prefix,
        Payload::RouteOrigin(_) => PayloadPduType::Ipv6Prefix,
        Payload::RouterKey(_) => PayloadPduType::RouterKey,
        Payload::Aspa(_) => PayloadPduType::Aspa,
    }
}
fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool {
ro.prefix().address.is_ipv4()
}
/// Extracts the numeric sort key (address, prefix length, max length, ASN)
/// for a route origin, keyed by address family.
fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey {
    let prefix = ro.prefix();
    let plen = prefix.prefix_length;
    let mlen = ro.max_length();
    let asn = ro.asn().into_u32();
    match prefix.address {
        IPAddress::V4(v4) => RouteOriginKey::V4 {
            addr: u32::from(v4),
            plen,
            mlen,
            asn,
        },
        IPAddress::V6(v6) => RouteOriginKey::V6 {
            addr: u128::from(v6),
            plen,
            mlen,
            asn,
        },
    }
}
/// RTR ordering for two route origins of the same address family.
///
/// Keys compare lexicographically as (address, max_length, prefix_length,
/// asn); announcements use the reversed order, withdrawals the natural one.
/// Mixed families are already separated by PDU type, so they compare equal.
fn compare_route_origin_for_rtr(
    a: &RouteOrigin,
    b: &RouteOrigin,
    announce: bool,
) -> Ordering {
    match (route_origin_key(a), route_origin_key(b)) {
        (
            RouteOriginKey::V4 { addr: addr_a, plen: plen_a, mlen: mlen_a, asn: asn_a },
            RouteOriginKey::V4 { addr: addr_b, plen: plen_b, mlen: mlen_b, asn: asn_b },
        ) => {
            // Tuple comparison is lexicographic, matching the original
            // addr -> mlen -> plen -> asn chain.
            let key_a = (addr_a, mlen_a, plen_a, asn_a);
            let key_b = (addr_b, mlen_b, plen_b, asn_b);
            if announce {
                key_b.cmp(&key_a)
            } else {
                key_a.cmp(&key_b)
            }
        }
        (
            RouteOriginKey::V6 { addr: addr_a, plen: plen_a, mlen: mlen_a, asn: asn_a },
            RouteOriginKey::V6 { addr: addr_b, plen: plen_b, mlen: mlen_b, asn: asn_b },
        ) => {
            let key_a = (addr_a, mlen_a, plen_a, asn_a);
            let key_b = (addr_b, mlen_b, plen_b, asn_b);
            if announce {
                key_b.cmp(&key_a)
            } else {
                key_a.cmp(&key_b)
            }
        }
        _ => Ordering::Equal,
    }
}
/// Orders router keys by SKI, then SPKI length, then SPKI bytes, then ASN.
fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> Ordering {
    match a.ski().cmp(&b.ski()) {
        Ordering::Equal => {}
        other => return other,
    }
    match a.spki().len().cmp(&b.spki().len()) {
        Ordering::Equal => {}
        other => return other,
    }
    match a.spki().cmp(b.spki()) {
        Ordering::Equal => {}
        other => return other,
    }
    a.asn().into_u32().cmp(&b.asn().into_u32())
}
/// Orders ASPA records by their customer ASN.
fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering {
    let left = a.customer_asn().into_u32();
    let right = b.customer_asn().into_u32();
    left.cmp(&right)
}

185
src/rtr/loader.rs Normal file
View File

@ -0,0 +1,185 @@
use std::fs;
use std::net::IpAddr;
use std::path::Path;
use std::str::FromStr;
use anyhow::{anyhow, Context, Result};
use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use crate::rtr::payload::{Payload, RouteOrigin};
/// One VRP line parsed from the text file, before conversion into a
/// `RouteOrigin`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedVrp {
    // Prefix address (IPv4 or IPv6).
    pub prefix_addr: IpAddr,
    // Prefix length in bits.
    pub prefix_len: u8,
    // Maximum allowed announcement length; >= prefix_len.
    pub max_len: u8,
    // Authorised origin AS number.
    pub asn: u32,
}
/// Loads VRPs from a text file and converts them into `Payload::RouteOrigin`s.
///
/// File format:
///
/// ```text
/// # prefix,max_len,asn
/// 10.0.0.0/24,24,65001
/// 10.0.1.0/24,24,65002
/// 2001:db8::/32,48,65003
/// ```
pub fn load_vrps_from_file(path: impl AsRef<Path>) -> Result<Vec<Payload>> {
    let path = path.as_ref();
    let content = fs::read_to_string(path)
        .with_context(|| format!("failed to read VRP file: {}", path.display()))?;
    let mut payloads = Vec::new();
    for (idx, raw_line) in content.lines().enumerate() {
        // 1-based line number for error messages.
        let line_no = idx + 1;
        let line = raw_line.trim();
        // Skip blank lines and '#' comments.
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let vrp = parse_vrp_line(line)
            .with_context(|| format!("invalid VRP line {}: {}", line_no, raw_line))?;
        payloads.push(Payload::RouteOrigin(build_route_origin(vrp)?));
    }
    Ok(payloads)
}
/// Parses a single VRP line.
///
/// Format:
/// `prefix/prefix_len,max_len,asn`
///
/// Example:
/// `10.0.0.0/24,24,65001`
pub fn parse_vrp_line(line: &str) -> Result<ParsedVrp> {
    let parts: Vec<_> = line.split(',').map(|s| s.trim()).collect();
    if parts.len() != 3 {
        return Err(anyhow!(
            "expected format: <prefix>/<prefix_len>,<max_len>,<asn>"
        ));
    }
    let prefix_part = parts[0];
    let max_len = u8::from_str(parts[1])
        .with_context(|| format!("invalid max_len: {}", parts[1]))?;
    let asn = u32::from_str(parts[2])
        .with_context(|| format!("invalid asn: {}", parts[2]))?;
    // Split the CIDR prefix into address and length.
    let (addr_str, prefix_len_str) = prefix_part
        .split_once('/')
        .ok_or_else(|| anyhow!("prefix must be in CIDR form, e.g. 10.0.0.0/24"))?;
    let prefix_addr = IpAddr::from_str(addr_str.trim())
        .with_context(|| format!("invalid IP address: {}", addr_str))?;
    let prefix_len = u8::from_str(prefix_len_str.trim())
        .with_context(|| format!("invalid prefix length: {}", prefix_len_str))?;
    // Reject lengths outside the address family's range before building.
    validate_vrp(prefix_addr, prefix_len, max_len)?;
    Ok(ParsedVrp {
        prefix_addr,
        prefix_len,
        max_len,
        asn,
    })
}
/// Validates prefix and max lengths against the address family's bit width.
/// The generated error messages are identical to the per-family checks
/// this unified version replaces.
fn validate_vrp(prefix_addr: IpAddr, prefix_len: u8, max_len: u8) -> Result<()> {
    let (family, max_bits) = match prefix_addr {
        IpAddr::V4(_) => ("IPv4", 32u8),
        IpAddr::V6(_) => ("IPv6", 128u8),
    };
    if prefix_len > max_bits {
        return Err(anyhow!("{} prefix length must be <= {}", family, max_bits));
    }
    if max_len > max_bits {
        return Err(anyhow!("{} max_len must be <= {}", family, max_bits));
    }
    if max_len < prefix_len {
        return Err(anyhow!("{} max_len must be >= prefix length", family));
    }
    Ok(())
}
/// Converts a parsed VRP into a `RouteOrigin` payload.
pub fn build_route_origin(vrp: ParsedVrp) -> Result<RouteOrigin> {
    let address = match vrp.prefix_addr {
        IpAddr::V4(v4) => IPAddress::from_ipv4(v4),
        IpAddr::V6(v6) => IPAddress::from_ipv6(v6),
    };
    Ok(RouteOrigin::new(
        IPAddressPrefix::new(address, vrp.prefix_len),
        vrp.max_len,
        Asn::from(vrp.asn),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: an IPv4 line splits into its three components.
    #[test]
    fn parse_ipv4_vrp_line() {
        let got = parse_vrp_line("10.0.0.0/24,24,65001").unwrap();
        assert_eq!(
            got,
            ParsedVrp {
                prefix_addr: IpAddr::from_str("10.0.0.0").unwrap(),
                prefix_len: 24,
                max_len: 24,
                asn: 65001,
            }
        );
    }

    // Happy path: IPv6, with max_len greater than prefix_len.
    #[test]
    fn parse_ipv6_vrp_line() {
        let got = parse_vrp_line("2001:db8::/32,48,65003").unwrap();
        assert_eq!(
            got,
            ParsedVrp {
                prefix_addr: IpAddr::from_str("2001:db8::").unwrap(),
                prefix_len: 32,
                max_len: 48,
                asn: 65003,
            }
        );
    }

    // max_len below prefix_len must be rejected.
    #[test]
    fn parse_rejects_invalid_max_len() {
        let err = parse_vrp_line("10.0.0.0/24,16,65001").unwrap_err();
        assert!(err.to_string().contains("max_len"));
    }

    // Out-of-range octet makes the address unparseable.
    #[test]
    fn parse_rejects_invalid_ip() {
        let err = parse_vrp_line("10.0.0.999/24,24,65001").unwrap_err();
        assert!(err.to_string().contains("invalid IP"));
    }

    // Missing field: only two comma-separated parts.
    #[test]
    fn parse_rejects_invalid_format() {
        let err = parse_vrp_line("10.0.0.0/24,24").unwrap_err();
        assert!(err.to_string().contains("expected format"));
    }
}

View File

@ -1,7 +1,9 @@
pub mod pdu;
pub mod cache;
pub mod payload;
mod store_db;
mod session;
mod error_type;
mod state;
pub mod store_db;
pub mod session;
pub mod error_type;
pub mod state;
pub mod server;
pub mod loader;

View File

@ -1,13 +1,19 @@
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use asn1_rs::nom::character::streaming::u64;
use serde::{Deserialize, Serialize};
use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::IPAddressPrefix;
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum PayloadPduType {
Ipv4Prefix = 4,
Ipv6Prefix = 6,
RouterKey = 9,
Aspa = 11,
}
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct Ski([u8; 20]);
@ -19,6 +25,14 @@ pub struct RouteOrigin {
}
impl RouteOrigin {
/// Creates a route origin from a prefix, its maximum announcement length,
/// and the authorised origin ASN.
pub fn new(prefix: IPAddressPrefix, max_length: u8, asn: Asn) -> Self {
    Self {
        prefix,
        max_length,
        asn,
    }
}
pub fn prefix(&self) -> &IPAddressPrefix {
&self.prefix
}
@ -37,13 +51,29 @@ impl RouteOrigin {
pub struct RouterKey {
subject_key_identifier: Ski,
asn: Asn,
subject_public_key_info: Arc<[u8]>,
subject_public_key_info: Vec<u8>,
}
impl RouterKey {
    /// Creates a BGPsec router key from its subject key identifier, ASN,
    /// and DER-encoded subject public key info.
    pub fn new(subject_key_identifier: Ski, asn: Asn, subject_public_key_info: Vec<u8>) -> Self {
        Self {
            subject_key_identifier,
            asn,
            subject_public_key_info,
        }
    }

    /// Subject key identifier (copied; `Ski` is a small fixed array).
    pub fn ski(&self) -> Ski {
        self.subject_key_identifier
    }

    pub fn asn(&self) -> Asn {
        self.asn
    }

    /// Borrowed subject public key info bytes.
    pub fn spki(&self) -> &[u8] {
        &self.subject_public_key_info
    }
}
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
@ -53,6 +83,16 @@ pub struct Aspa {
}
impl Aspa {
/// Creates an ASPA record. Provider ASNs are sorted and deduplicated so
/// logically equal records compare (and hash) equal.
pub fn new(customer_asn: Asn, mut provider_asns: Vec<Asn>) -> Self {
    provider_asns.sort();
    provider_asns.dedup();
    Self {
        customer_asn,
        provider_asns,
    }
}
pub fn customer_asn(&self) -> Asn {
self.customer_asn
}
@ -66,7 +106,7 @@ impl Aspa {
#[derive(Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Payload {
/// A route origin authorisation.
/// A route origin.
RouteOrigin(RouteOrigin),
/// A BGPsec router key.
@ -91,6 +131,10 @@ pub struct Timing {
}
impl Timing {
/// Creates a timing configuration from refresh/retry/expire values,
/// all in seconds.
pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self {
    Self { refresh, retry, expire }
}
pub fn refresh(self) -> Duration {
Duration::from_secs(u64::from(self.refresh))
}
@ -104,3 +148,13 @@ impl Timing {
}
}
impl Default for Timing {
fn default() -> Self {
Self {
refresh: 3600,
retry: 600,
expire: 7200,
}
}
}

View File

@ -10,10 +10,11 @@ use anyhow::Result;
use std::slice;
use anyhow::bail;
use serde::Serialize;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
pub const HEADER_LEN: u32 = 8;
pub const HEADER_LEN: usize = 8;
pub const MAX_PDU_LEN: u32 = 65535;
pub const IPV4_PREFIX_LEN: u32 = 20;
pub const IPV6_PREFIX_LEN: u32 = 32;
@ -78,6 +79,10 @@ macro_rules! concrete {
self.header.session_id()
}
pub fn pdu(&self) -> u8 {
self.header.pdu()
}
/// Returns the PDU size.
///
/// The size is returned as a `u32` since that type is used in
@ -126,7 +131,7 @@ macro_rules! concrete {
) -> Result<Result<Self, Header>, io::Error> {
let mut res = Self::default();
sock.read_exact(res.header.as_mut()).await?;
if res.header.pdu() == Error::PDU {
if res.header.pdu() == ErrorReport::PDU {
// Since we should drop the session after an error, we
// can safely ignore all the rest of the error for now.
return Ok(Err(res.header))
@ -183,7 +188,7 @@ macro_rules! concrete {
// 所有PDU公共头部信息
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Header {
version: u8,
pdu: u8,
@ -203,36 +208,34 @@ impl Header {
}
}
pub async fn read(sock: &mut TcpStream) -> Result<Self> {
pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self, io::Error> {
let mut buf = [0u8; HEADER_LEN];
// 1. 精确读取 8 字节
sock.read_exact(&mut buf).await?;
// 2. 手动解析(大端)
let version = buf[0];
let pdu = buf[1];
let reserved = u16::from_be_bytes([buf[2], buf[3]]);
let length = u32::from_be_bytes([
buf[4], buf[5], buf[6], buf[7],
]);
let session_id = u16::from_be_bytes([buf[2], buf[3]]);
let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
// 3. 基础合法性校验
if length < HEADER_LEN{
bail!("Invalid PDU length");
if length < HEADER_LEN as u32 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid PDU length",
));
}
// 限制最大长度
if length > MAX_PDU_LEN {
bail!("PDU too large");
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"PDU too large",
));
}
Ok(Self {
version,
pdu,
session_id: reserved,
length,
session_id: session_id.to_be(),
length: length.to_be(),
})
}
@ -259,7 +262,7 @@ impl Header {
common!(Header);
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct HeaderWithFlags {
version: u8,
pdu: u8,
@ -294,7 +297,7 @@ impl HeaderWithFlags {
]);
// 3. 基础合法性校验
if length < HEADER_LEN{
if length < HEADER_LEN as u32{
bail!("Invalid PDU length");
}
@ -324,7 +327,7 @@ impl HeaderWithFlags {
// Serial Notify
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct SerialNotify {
header: Header,
serial_number: u32,
@ -340,13 +343,17 @@ impl SerialNotify {
}
}
pub fn serial_number(self) -> u32 {
self.serial_number
}
}
concrete!(SerialNotify);
// Serial Query
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct SerialQuery {
header: Header,
serial_number: u32,
@ -372,7 +379,7 @@ concrete!(SerialQuery);
// Reset Query
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ResetQuery {
header: Header
}
@ -382,7 +389,7 @@ impl ResetQuery {
pub fn new(version: u8) -> Self {
ResetQuery {
header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN),
header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32),
}
}
}
@ -392,7 +399,7 @@ concrete!(ResetQuery);
// Cache Response
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct CacheResponse {
header: Header,
}
@ -402,7 +409,7 @@ impl CacheResponse {
pub fn new(version: u8, session_id: u16) -> Self {
CacheResponse {
header: Header::new(version, Self::PDU, session_id, HEADER_LEN),
header: Header::new(version, Self::PDU, session_id, HEADER_LEN as u32),
}
}
}
@ -411,7 +418,7 @@ concrete!(CacheResponse);
// Flags
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Flags(u8);
impl Flags {
@ -434,7 +441,7 @@ impl Flags {
// IPv4 Prefix
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct IPv4Prefix {
header: Header,
@ -479,7 +486,7 @@ concrete!(IPv4Prefix);
// IPv6 Prefix
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct IPv6Prefix {
header: Header,
@ -524,7 +531,7 @@ concrete!(IPv6Prefix);
// End of Data
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)]
pub enum EndOfData {
V0(EndOfDataV0),
V1(EndOfDataV1),
@ -544,7 +551,7 @@ impl EndOfData {
}
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct EndOfDataV0 {
header: Header,
@ -567,7 +574,7 @@ concrete!(EndOfDataV0);
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct EndOfDataV1 {
header: Header,
@ -605,7 +612,7 @@ concrete!(EndOfDataV1);
// Cache Reset
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct CacheReset {
header: Header,
}
@ -615,7 +622,7 @@ impl CacheReset {
pub fn new(version: u8) -> Self{
CacheReset {
header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN)
header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32)
}
}
}
@ -624,7 +631,7 @@ concrete!(CacheReset);
// Error Report
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ErrorReport {
octets: Vec<u8>,
}
@ -703,7 +710,7 @@ impl ErrorReport {
// TODO: 补全
// Router Key
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct RouterKey {
header: HeaderWithFlags,
@ -727,13 +734,13 @@ impl RouterKey {
+ 1 // flags
// + self.ski.as_ref().len()
+ 4 // ASN
+ self.subject_public_key_info.len() as u32;
+ self.subject_public_key_info.len();
let header = HeaderWithFlags::new(
self.header.version(),
Self::PDU,
self.flags,
length,
length as u32,
);
w.write_all(&[
@ -755,6 +762,7 @@ impl RouterKey {
// ASPA
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Aspa{
header: HeaderWithFlags,
@ -773,13 +781,13 @@ impl Aspa {
let length = HEADER_LEN
+ 1
+ 4
+ (self.provider_asns.len() as u32 * 4);
+ (self.provider_asns.len() * 4);
let header = HeaderWithFlags::new(
self.header.version(),
Self::PDU,
Flags::new(self.header.flags),
length,
length as u32,
);
w.write_all(&[
@ -836,7 +844,7 @@ mod tests {
assert_eq!(decoded.version(), 1);
assert_eq!(decoded.session_id(), 42);
assert_eq!(decoded.serial_number, 100u32.to_be());
assert_eq!(decoded.serial_number(), 100u32.to_be());
}
#[tokio::test]

14
src/rtr/server/config.rs Normal file
View File

@ -0,0 +1,14 @@
/// Tunables for an RTR service instance.
#[derive(Debug, Clone)]
pub struct RtrServiceConfig {
    /// Upper bound on concurrently served client connections.
    pub max_connections: usize,
    /// Capacity of the broadcast queue carrying cache-update notifications.
    pub notify_queue_size: usize,
}
impl Default for RtrServiceConfig {
    /// Defaults: 1024 concurrent connections, 1024 queued notifications.
    fn default() -> Self {
        RtrServiceConfig {
            max_connections: 1024,
            notify_queue_size: 1024,
        }
    }
}

View File

@ -0,0 +1,79 @@
use std::net::SocketAddr;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use anyhow::{Context, Result};
use tokio::net::TcpStream;
use tokio::sync::{broadcast, watch, OwnedSemaphorePermit};
use tracing::error;
use tokio_rustls::TlsAcceptor;
use crate::rtr::cache::SharedRtrCache;
use crate::rtr::session::RtrSession;
/// RAII bookkeeping for one accepted client connection: bumps the shared
/// active-connection counter on creation and releases both the counter and
/// the owned semaphore permit when dropped.
pub struct ConnectionGuard {
    // Shared counter of live sessions, decremented on drop.
    active_connections: Arc<AtomicUsize>,
    // Held for the guard's lifetime; dropping it frees a connection slot.
    _permit: OwnedSemaphorePermit,
}
impl ConnectionGuard {
    /// Registers a new active connection and takes ownership of its permit.
    pub fn new(
        active_connections: Arc<AtomicUsize>,
        permit: OwnedSemaphorePermit,
    ) -> Self {
        let guard = ConnectionGuard {
            active_connections,
            _permit: permit,
        };
        // Relaxed suffices: the counter is only read for reporting.
        guard.active_connections.fetch_add(1, Ordering::Relaxed);
        guard
    }
}
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        self.active_connections.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Drives a plain-TCP RTR session to completion on an already-accepted
/// stream; logs and propagates any session error.
pub async fn handle_tcp_connection(
    cache: SharedRtrCache,
    stream: TcpStream,
    peer_addr: SocketAddr,
    notify_rx: broadcast::Receiver<()>,
    shutdown_rx: watch::Receiver<bool>,
) -> Result<()> {
    RtrSession::new(cache, stream, notify_rx, shutdown_rx)
        .run()
        .await
        .map_err(|err| {
            error!("RTR TCP session run failed for {}: {:?}", peer_addr, err);
            err
        })
}
/// Performs the server-side TLS handshake on an accepted stream, then
/// drives an RTR session over the resulting TLS stream; logs and
/// propagates handshake and session errors.
pub async fn handle_tls_connection(
    cache: SharedRtrCache,
    stream: TcpStream,
    peer_addr: SocketAddr,
    acceptor: TlsAcceptor,
    notify_rx: broadcast::Receiver<()>,
    shutdown_rx: watch::Receiver<bool>,
) -> Result<()> {
    let tls_stream = acceptor
        .accept(stream)
        .await
        .with_context(|| format!("TLS handshake failed for {}", peer_addr))?;
    RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx)
        .run()
        .await
        .map_err(|err| {
            error!("RTR TLS session run failed for {}: {:?}", peer_addr, err);
            err
        })
}

219
src/rtr/server/listener.rs Normal file
View File

@ -0,0 +1,219 @@
use std::net::SocketAddr;
use std::path::Path;
use std::sync::{
Arc,
atomic::AtomicUsize,
};
use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tokio::sync::{broadcast, watch, Semaphore};
use tracing::{info, warn};
use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor;
use crate::rtr::cache::SharedRtrCache;
use crate::rtr::server::connection::{ConnectionGuard, handle_tcp_connection, handle_tls_connection};
use crate::rtr::server::tls::load_rustls_server_config;
/// One RTR listener bound to a single socket address.
///
/// TCP and TLS share this type; the transport is chosen by which `run_*`
/// method is invoked. Created via `RtrService::tcp_server`/`tls_server`,
/// it shares the cache, notification channel, shutdown channel and
/// connection accounting with the owning service.
pub struct RtrServer {
    // Address the TcpListener is bound to in run_tcp / run_tls.
    bind_addr: SocketAddr,
    // Shared validated-payload cache served to routers.
    cache: SharedRtrCache,
    // Broadcast side of the cache-updated channel; each session subscribes.
    notify_tx: broadcast::Sender<()>,
    // Watch side of the shutdown flag; the listener and sessions subscribe.
    shutdown_tx: watch::Sender<bool>,
    // Caps concurrent sessions; one permit is held per live connection.
    connection_limiter: Arc<Semaphore>,
    // Informational count of live sessions (relaxed atomics).
    active_connections: Arc<AtomicUsize>,
}
impl RtrServer {
    /// Builds a listener from shared service state; binding is deferred to
    /// the `run_*` methods.
    pub fn new(
        bind_addr: SocketAddr,
        cache: SharedRtrCache,
        notify_tx: broadcast::Sender<()>,
        shutdown_tx: watch::Sender<bool>,
        connection_limiter: Arc<Semaphore>,
        active_connections: Arc<AtomicUsize>,
    ) -> Self {
        Self {
            bind_addr,
            cache,
            notify_tx,
            shutdown_tx,
            connection_limiter,
            active_connections,
        }
    }
    /// Address this listener binds to.
    pub fn bind_addr(&self) -> SocketAddr {
        self.bind_addr
    }
    /// Clone of the shared cache handle.
    pub fn cache(&self) -> SharedRtrCache {
        self.cache.clone()
    }
    /// Current number of live sessions (relaxed read, informational only).
    pub fn active_connections(&self) -> usize {
        self.active_connections.load(std::sync::atomic::Ordering::Relaxed)
    }
    /// Accept loop for plain TCP. Runs until the shutdown flag flips to
    /// `true` (or the shutdown channel closes); each accepted connection is
    /// served on its own spawned task, guarded by a `ConnectionGuard`.
    pub async fn run_tcp(self) -> Result<()> {
        let listener = TcpListener::bind(self.bind_addr)
            .await
            .with_context(|| format!("failed to bind TCP RTR server on {}", self.bind_addr))?;
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        info!("RTR TCP server listening on {}", self.bind_addr);
        loop {
            tokio::select! {
                changed = shutdown_rx.changed() => {
                    match changed {
                        Ok(()) => {
                            // Only a true value means shutdown; spurious
                            // notifications with false are ignored.
                            if *shutdown_rx.borrow() {
                                info!("RTR TCP listener {} shutting down", self.bind_addr);
                                return Ok(());
                            }
                        }
                        Err(_) => {
                            // Sender dropped: treat as shutdown.
                            info!("RTR TCP listener {} shutdown channel closed", self.bind_addr);
                            return Ok(());
                        }
                    }
                }
                accept_res = listener.accept() => {
                    let (stream, peer_addr) = match accept_res {
                        Ok(v) => v,
                        Err(err) => {
                            // Transient accept failures don't kill the listener.
                            warn!("RTR TCP accept failed: {}", err);
                            continue;
                        }
                    };
                    // Best-effort latency tweak; failure is non-fatal.
                    if let Err(err) = stream.set_nodelay(true) {
                        warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err);
                    }
                    // Reject (rather than queue) when at the connection cap.
                    let permit = match self.connection_limiter.clone().try_acquire_owned() {
                        Ok(permit) => permit,
                        Err(_) => {
                            warn!(
                                "RTR TCP connection rejected for {}: max connections reached ({})",
                                peer_addr,
                                self.connection_limiter.available_permits()
                            );
                            drop(stream);
                            continue;
                        }
                    };
                    let cache = self.cache.clone();
                    let notify_rx = self.notify_tx.subscribe();
                    let shutdown_rx = self.shutdown_tx.subscribe();
                    let active_connections = self.active_connections.clone();
                    info!("RTR TCP client connected: {}", peer_addr);
                    tokio::spawn(async move {
                        // Guard keeps the permit and the counter consistent
                        // for the whole session, including error paths.
                        let _guard = ConnectionGuard::new(active_connections, permit);
                        if let Err(err) =
                            handle_tcp_connection(cache, stream, peer_addr, notify_rx, shutdown_rx).await
                        {
                            warn!("RTR TCP session {} ended with error: {:?}", peer_addr, err);
                        } else {
                            info!("RTR TCP session {} closed", peer_addr);
                        }
                    });
                }
            }
        }
    }
    /// Convenience wrapper: loads certificate chain and key from PEM files
    /// and runs the TLS accept loop.
    pub async fn run_tls_from_pem(
        self,
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
    ) -> Result<()> {
        let tls_config = Arc::new(load_rustls_server_config(cert_path, key_path)?);
        self.run_tls(tls_config).await
    }
    /// Accept loop for TLS. Same shutdown and connection-limit behavior as
    /// `run_tcp`; the TLS handshake itself happens on the per-connection
    /// task inside `handle_tls_connection`.
    pub async fn run_tls(self, tls_config: Arc<ServerConfig>) -> Result<()> {
        let listener = TcpListener::bind(self.bind_addr)
            .await
            .with_context(|| format!("failed to bind TLS RTR server on {}", self.bind_addr))?;
        let acceptor = TlsAcceptor::from(tls_config);
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        info!("RTR TLS server listening on {}", self.bind_addr);
        loop {
            tokio::select! {
                changed = shutdown_rx.changed() => {
                    match changed {
                        Ok(()) => {
                            if *shutdown_rx.borrow() {
                                info!("RTR TLS listener {} shutting down", self.bind_addr);
                                return Ok(());
                            }
                        }
                        Err(_) => {
                            info!("RTR TLS listener {} shutdown channel closed", self.bind_addr);
                            return Ok(());
                        }
                    }
                }
                accept_res = listener.accept() => {
                    let (stream, peer_addr) = match accept_res {
                        Ok(v) => v,
                        Err(err) => {
                            warn!("RTR TLS accept failed: {}", err);
                            continue;
                        }
                    };
                    if let Err(err) = stream.set_nodelay(true) {
                        warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err);
                    }
                    let permit = match self.connection_limiter.clone().try_acquire_owned() {
                        Ok(permit) => permit,
                        Err(_) => {
                            warn!("RTR TLS connection rejected for {}: max connections reached", peer_addr);
                            drop(stream);
                            continue;
                        }
                    };
                    let cache = self.cache.clone();
                    let acceptor = acceptor.clone();
                    let notify_rx = self.notify_tx.subscribe();
                    let shutdown_rx = self.shutdown_tx.subscribe();
                    let active_connections = self.active_connections.clone();
                    info!("RTR TLS client connected: {}", peer_addr);
                    tokio::spawn(async move {
                        let _guard = ConnectionGuard::new(active_connections, permit);
                        if let Err(err) = handle_tls_connection(
                            cache,
                            stream,
                            peer_addr,
                            acceptor,
                            notify_rx,
                            shutdown_rx,
                        ).await {
                            warn!("RTR TLS session {} ended with error: {:?}", peer_addr, err);
                        } else {
                            info!("RTR TLS session {} closed", peer_addr);
                        }
                    });
                }
            }
        }
    }
}

12
src/rtr/server/mod.rs Normal file
View File

@ -0,0 +1,12 @@
// RTR server module layout: configuration, per-connection handling,
// socket listeners, cache-update notification, service orchestration
// and TLS setup.
pub mod config;
pub mod connection;
pub mod listener;
pub mod notifier;
pub mod service;
pub mod tls;
// Flattened re-exports forming the public server API surface.
pub use config::RtrServiceConfig;
pub use listener::RtrServer;
pub use notifier::RtrNotifier;
pub use service::{RtrService, RunningRtrService};
pub use tls::load_rustls_server_config;

View File

@ -0,0 +1,16 @@
use tokio::sync::broadcast;
/// Cheap, cloneable handle that producers use to tell connected RTR
/// sessions the cache has changed.
#[derive(Clone)]
pub struct RtrNotifier {
    // Broadcast side of the cache-updated channel.
    tx: broadcast::Sender<()>,
}
impl RtrNotifier {
    /// Wraps an existing broadcast sender.
    pub fn new(tx: broadcast::Sender<()>) -> Self {
        RtrNotifier { tx }
    }
    /// Fires a cache-updated event. A send error only means no session is
    /// currently subscribed, so it is deliberately ignored.
    pub fn notify_cache_updated(&self) {
        if self.tx.send(()).is_err() {
            // No subscribers right now; nothing to do.
        }
    }
}

154
src/rtr/server/service.rs Normal file
View File

@ -0,0 +1,154 @@
use std::net::SocketAddr;
use std::path::Path;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio::sync::{broadcast, watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::error;
use crate::rtr::cache::SharedRtrCache;
use crate::rtr::server::config::RtrServiceConfig;
use crate::rtr::server::listener::RtrServer;
use crate::rtr::server::notifier::RtrNotifier;
/// Owns the shared state of one RTR service instance: the cache handle,
/// the cache-update broadcast channel, the shutdown channel, the
/// connection limiter and the active-connection counter. Listeners
/// (`RtrServer`) are created from it and share that state.
pub struct RtrService {
    // Shared validated-payload cache served to routers.
    cache: SharedRtrCache,
    // Broadcast side of the cache-updated channel.
    notify_tx: broadcast::Sender<()>,
    // Watch side of the shutdown flag; flipping it to true stops listeners.
    shutdown_tx: watch::Sender<bool>,
    // Caps concurrent sessions across all listeners of this service.
    connection_limiter: Arc<Semaphore>,
    // Informational count of live sessions (relaxed atomics).
    active_connections: Arc<AtomicUsize>,
    // Retained so max_connections() can report the configured cap.
    config: RtrServiceConfig,
}
impl RtrService {
    /// Creates a service with the default configuration.
    pub fn new(cache: SharedRtrCache) -> Self {
        Self::with_config(cache, RtrServiceConfig::default())
    }
    /// Creates a service with an explicit configuration.
    pub fn with_config(cache: SharedRtrCache, config: RtrServiceConfig) -> Self {
        let (notify_tx, _) = broadcast::channel(config.notify_queue_size);
        let (shutdown_tx, _) = watch::channel(false);
        Self {
            cache,
            notify_tx,
            shutdown_tx,
            connection_limiter: Arc::new(Semaphore::new(config.max_connections)),
            active_connections: Arc::new(AtomicUsize::new(0)),
            config,
        }
    }
    /// Clone of the shared cache handle.
    pub fn cache(&self) -> SharedRtrCache {
        self.cache.clone()
    }
    /// Cloneable notifier for waking connected sessions on cache updates.
    pub fn notifier(&self) -> RtrNotifier {
        RtrNotifier::new(self.notify_tx.clone())
    }
    /// Broadcasts a cache-updated event; "no receivers" is not an error.
    pub fn notify_cache_updated(&self) {
        let _ = self.notify_tx.send(());
    }
    /// Current number of live sessions (relaxed read, informational only).
    pub fn active_connections(&self) -> usize {
        self.active_connections.load(Ordering::Relaxed)
    }
    /// Configured connection cap.
    pub fn max_connections(&self) -> usize {
        self.config.max_connections
    }
    // Builds a listener sharing this service's state. TCP and TLS servers
    // are identical at this level; the transport is chosen by which run_*
    // method is later invoked on the returned server.
    fn make_server(&self, bind_addr: SocketAddr) -> RtrServer {
        RtrServer::new(
            bind_addr,
            self.cache.clone(),
            self.notify_tx.clone(),
            self.shutdown_tx.clone(),
            self.connection_limiter.clone(),
            self.active_connections.clone(),
        )
    }
    /// Listener intended to be driven with `run_tcp`.
    pub fn tcp_server(&self, bind_addr: SocketAddr) -> RtrServer {
        self.make_server(bind_addr)
    }
    /// Listener intended to be driven with `run_tls` / `run_tls_from_pem`.
    pub fn tls_server(&self, bind_addr: SocketAddr) -> RtrServer {
        self.make_server(bind_addr)
    }
    /// Spawns a TCP listener task; errors are logged, not propagated.
    pub fn spawn_tcp(&self, bind_addr: SocketAddr) -> JoinHandle<()> {
        let server = self.tcp_server(bind_addr);
        tokio::spawn(async move {
            if let Err(err) = server.run_tcp().await {
                error!("RTR TCP server {} exited with error: {:?}", bind_addr, err);
            }
        })
    }
    /// Spawns a TLS listener task using PEM cert/key files; errors are
    /// logged, not propagated.
    pub fn spawn_tls_from_pem(
        &self,
        bind_addr: SocketAddr,
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
    ) -> JoinHandle<()> {
        // Own the paths so they can move into the spawned task.
        let cert_path = cert_path.as_ref().to_path_buf();
        let key_path = key_path.as_ref().to_path_buf();
        let server = self.tls_server(bind_addr);
        tokio::spawn(async move {
            if let Err(err) = server.run_tls_from_pem(cert_path, key_path).await {
                error!("RTR TLS server {} exited with error: {:?}", bind_addr, err);
            }
        })
    }
    /// Spawns one TCP and one TLS listener and returns a handle used to
    /// shut both down and await their completion.
    pub fn spawn_tcp_and_tls_from_pem(
        &self,
        tcp_bind_addr: SocketAddr,
        tls_bind_addr: SocketAddr,
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
    ) -> RunningRtrService {
        let tcp_handle = self.spawn_tcp(tcp_bind_addr);
        let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path);
        RunningRtrService {
            shutdown_tx: self.shutdown_tx.clone(),
            handles: vec![tcp_handle, tls_handle],
        }
    }
    /// Spawns only a TCP listener and returns its running-service handle.
    pub fn spawn_tcp_only(&self, tcp_bind_addr: SocketAddr) -> RunningRtrService {
        let tcp_handle = self.spawn_tcp(tcp_bind_addr);
        RunningRtrService {
            shutdown_tx: self.shutdown_tx.clone(),
            handles: vec![tcp_handle],
        }
    }
}
/// Handle to the spawned listener task(s) of an `RtrService`, used to
/// request shutdown and await task completion.
pub struct RunningRtrService {
    // Flipping this to true tells all listeners/sessions to stop.
    shutdown_tx: watch::Sender<bool>,
    // Join handles of the spawned listener tasks.
    handles: Vec<JoinHandle<()>>,
}
impl RunningRtrService {
    /// Signals every listener to stop accepting and wind sessions down.
    pub fn shutdown(&self) {
        let _ = self.shutdown_tx.send(true);
    }
    /// Waits for every listener task to finish; join errors are ignored.
    pub async fn wait(self) {
        for handle in self.handles {
            let _ = handle.await;
        }
    }
}

52
src/rtr/server/tls.rs Normal file
View File

@ -0,0 +1,52 @@
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use anyhow::{anyhow, Context, Result};
use rustls::ServerConfig;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
/// Loads a rustls `ServerConfig` (no client authentication) from a PEM
/// certificate-chain file and a PEM private-key file.
///
/// Returns an error, with path context, if either file cannot be read or
/// parsed, or if the certificate/key pair is rejected by rustls.
pub fn load_rustls_server_config(
    cert_path: impl AsRef<Path>,
    key_path: impl AsRef<Path>,
) -> Result<ServerConfig> {
    // Borrow the paths directly; no need to allocate owned PathBufs here.
    let cert_path = cert_path.as_ref();
    let key_path = key_path.as_ref();
    let certs = load_certs(cert_path)
        .with_context(|| format!("failed to load certs from {}", cert_path.display()))?;
    let key = load_private_key(key_path)
        .with_context(|| format!("failed to load private key from {}", key_path.display()))?;
    let config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| anyhow!("invalid certificate/key pair: {}", e))?;
    Ok(config)
}
/// Reads all PEM certificates from `path`; errors if the file contains
/// none or any entry fails to parse.
fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
    let mut reader = BufReader::new(File::open(path)?);
    let parsed = rustls_pemfile::certs(&mut reader)
        .collect::<std::result::Result<Vec<_>, _>>()?;
    if parsed.is_empty() {
        Err(anyhow!("no certificates found in {}", path.display()))
    } else {
        Ok(parsed)
    }
}
/// Reads the first PEM private key from `path`; errors if none is found.
fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>> {
    let mut reader = BufReader::new(File::open(path)?);
    match rustls_pemfile::private_key(&mut reader)? {
        Some(key) => Ok(key),
        None => Err(anyhow!("no private key found in {}", path.display())),
    }
}

View File

@ -1,16 +1,16 @@
use std::sync::Arc;
use anyhow::{bail, Result};
use tokio::io;
use tokio::net::TcpStream;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{broadcast, watch};
use tracing::warn;
use crate::rtr::cache::{Delta, RtrCache, SerialResult};
use crate::data_model::resources::ip_resources::IPAddress;
use crate::rtr::cache::{Delta, SerialResult, SharedRtrCache};
use crate::rtr::error_type::ErrorCode;
use crate::rtr::payload::{Payload, RouteOrigin, Timing};
use crate::rtr::payload::{Payload, RouteOrigin};
use crate::rtr::pdu::{
CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix,
ResetQuery, SerialQuery,
ResetQuery, SerialNotify, SerialQuery,
};
const SUPPORTED_MAX_VERSION: u8 = 2;
@ -26,34 +26,68 @@ enum SessionState {
Closed,
}
pub struct RtrSession {
cache: Arc<RtrCache>,
pub struct RtrSession<S> {
cache: SharedRtrCache,
version: Option<u8>,
stream: TcpStream,
stream: S,
state: SessionState,
notify_rx: broadcast::Receiver<()>,
shutdown_rx: watch::Receiver<bool>,
}
impl RtrSession {
pub fn new(cache: Arc<RtrCache>, stream: TcpStream) -> Self {
impl<S> RtrSession<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(
cache: SharedRtrCache,
stream: S,
notify_rx: broadcast::Receiver<()>,
shutdown_rx: watch::Receiver<bool>,
) -> Self {
Self {
cache,
version: None,
stream,
state: SessionState::Connected,
notify_rx,
shutdown_rx,
}
}
pub async fn run(mut self) -> Result<()> {
loop {
let header = match Header::read(&mut self.stream).await {
tokio::select! {
changed = self.shutdown_rx.changed() => {
match changed {
Ok(()) => {
if *self.shutdown_rx.borrow() {
self.state = SessionState::Closed;
return Ok(());
}
}
Err(_) => {
// shutdown sender dropped按关闭处理
self.state = SessionState::Closed;
return Ok(());
}
}
}
header_res = Header::read(&mut self.stream) => {
let header = match header_res {
Ok(h) => h,
Err(_) => return Ok(()),
Err(_) => {
self.state = SessionState::Closed;
return Ok(());
}
};
if self.version.is_none() {
self.negotiate_version(header.version()).await?;
} else if header.version() != self.version.unwrap() {
self.send_unsupported_version(self.version.unwrap()).await?;
self.state = SessionState::Closed;
bail!("version changed within session");
}
@ -74,12 +108,35 @@ impl RtrSession {
return Ok(());
}
_ => {
self.send_error(header.version(), ErrorCode::UnsupportedPduType, Some(&header), &[])
self.send_error(
header.version(),
ErrorCode::UnsupportedPduType,
Some(&header),
&[],
)
.await?;
self.state = SessionState::Closed;
return Ok(());
}
}
}
notify_res = self.notify_rx.recv(),
if self.state == SessionState::Established && self.version.is_some() => {
match notify_res {
Ok(()) => {
self.handle_notify().await?;
}
Err(broadcast::error::RecvError::Lagged(_)) => {
self.handle_notify().await?;
}
Err(broadcast::error::RecvError::Closed) => {
// notify 通道关闭,不影响已有会话,继续跑,真正关闭由 shutdown_rx 控制
}
}
}
}
}
}
async fn negotiate_version(&mut self, router_version: u8) -> io::Result<u8> {
@ -115,33 +172,65 @@ impl RtrSession {
}
async fn handle_reset_query(&mut self) -> Result<()> {
let (payloads, session_id, serial) = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
let snapshot = cache.snapshot();
let payloads = snapshot.payloads_for_rtr();
let session_id = cache.session_id();
let serial = cache.serial();
(payloads, session_id, serial)
};
self.write_cache_response(session_id).await?;
self.send_payloads(&payloads, true).await?;
self.write_end_of_data(session_id, serial).await?;
self.state = SessionState::Established;
let snapshot = self.cache.snapshot();
self.write_cache_response().await?;
self.send_payloads(snapshot.payloads(), true).await?;
self.write_end_of_data(self.cache.session_id(), self.cache.serial())
.await?;
Ok(())
}
async fn handle_serial(&mut self, client_session: u16, client_serial: u32) -> Result<()> {
let current_session = self.cache.session_id();
let current_serial = self.cache.serial();
let serial_result = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
cache.get_deltas_since(client_session, client_serial)
};
match self.cache.get_deltas_since(client_session, client_serial) {
match serial_result {
SerialResult::ResetRequired => {
self.write_cache_reset().await?;
self.state = SessionState::Established;
return Ok(());
}
SerialResult::UpToDate => {
let (current_session, current_serial) = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
(cache.session_id(), cache.serial())
};
self.write_end_of_data(current_session, current_serial)
.await?;
self.state = SessionState::Established;
return Ok(());
}
SerialResult::Deltas(deltas) => {
self.write_cache_response().await?;
let (current_session, current_serial) = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
(cache.session_id(), cache.serial())
};
self.write_cache_response(current_session).await?;
for delta in deltas {
self.send_delta(&delta).await?;
}
@ -154,32 +243,55 @@ impl RtrSession {
Ok(())
}
async fn write_cache_response(&mut self) -> Result<()> {
let version = self.version.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")
})?;
async fn handle_notify(&mut self) -> Result<()> {
let (session_id, serial) = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
(cache.session_id(), cache.serial())
};
CacheResponse::new(version, self.cache.session_id())
self.send_serial_notify(session_id, serial).await
}
fn version(&self) -> io::Result<u8> {
self.version
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "version not negotiated"))
}
async fn send_serial_notify(&mut self, session_id: u16, serial: u32) -> Result<()> {
let version = self.version()?;
SerialNotify::new(version, session_id, serial)
.write(&mut self.stream)
.await?;
Ok(())
}
async fn write_cache_response(&mut self, session_id: u16) -> Result<()> {
let version = self.version()?;
CacheResponse::new(version, session_id)
.write(&mut self.stream)
.await?;
Ok(())
}
async fn write_cache_reset(&mut self) -> Result<()> {
let version = self.version.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")
})?;
let version = self.version()?;
CacheReset::new(version).write(&mut self.stream).await?;
Ok(())
}
async fn write_end_of_data(&mut self, session_id: u16, serial: u32) -> Result<()> {
let version = self.version.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")
})?;
let version = self.version()?;
let timing = {
let cache = self
.cache
.read()
.map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?;
cache.timing()
};
let timing = self.timing();
let end = EndOfData::new(version, session_id, serial, timing);
match end {
EndOfData::V0(pdu) => pdu.write(&mut self.stream).await?,
@ -189,20 +301,20 @@ impl RtrSession {
Ok(())
}
async fn send_payloads(&mut self, payloads: Vec<Payload>, announce: bool) -> Result<()> {
async fn send_payloads(&mut self, payloads: &[Payload], announce: bool) -> Result<()> {
for payload in payloads {
self.send_payload(&payload, announce).await?;
self.send_payload(payload, announce).await?;
}
Ok(())
}
async fn send_delta(&mut self, delta: &Arc<Delta>) -> Result<()> {
for payload in delta.withdrawn() {
self.send_payload(payload, false).await?;
}
async fn send_delta(&mut self, delta: &Delta) -> Result<()> {
for payload in delta.announced() {
self.send_payload(payload, true).await?;
}
for payload in delta.withdrawn() {
self.send_payload(payload, false).await?;
}
Ok(())
}
@ -222,9 +334,7 @@ impl RtrSession {
}
async fn send_route_origin(&mut self, origin: &RouteOrigin, announce: bool) -> Result<()> {
let version = self.version.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")
})?;
let version = self.version()?;
let flags = Flags::new(if announce {
ANNOUNCE_FLAG
@ -236,16 +346,18 @@ impl RtrSession {
let prefix_len = prefix.prefix_length;
let max_len = origin.max_length();
if let Some(v4) = prefix.address.to_ipv4() {
match prefix.address {
IPAddress::V4(v4) => {
IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn())
.write(&mut self.stream)
.await?;
} else {
let v6 = prefix.address.to_ipv6();
}
IPAddress::V6(v6) => {
IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn())
.write(&mut self.stream)
.await?;
}
}
Ok(())
}
@ -257,24 +369,10 @@ impl RtrSession {
offending_header: Option<&Header>,
text: &[u8],
) -> io::Result<()> {
let offending = offending_header
.map(|h| h.as_ref())
.unwrap_or(&[]);
let offending = offending_header.map(|h| h.as_ref()).unwrap_or(&[]);
ErrorReport::new(version, code.as_u16(), offending, text)
.write(&mut self.stream)
.await
}
fn timing(&self) -> Timing {
let refresh = self.cache.refresh_interval().as_secs();
let retry = self.cache.retry_interval().as_secs();
let expire = self.cache.expire_interval().as_secs();
Timing {
refresh: refresh.min(u32::MAX as u64) as u32,
retry: retry.min(u32::MAX as u64) as u32,
expire: expire.min(u32::MAX as u64) as u32,
}
}
}

View File

@ -274,20 +274,30 @@ impl RtrStore {
}
pub fn load_deltas_since(&self, serial: u32) -> Result<Vec<Delta>> {
let cf_handle = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let mut out = Vec::new();
let start_key = delta_key(serial.wrapping_add(1));
let iter = self
let cf_handle = self
.db
.iterator_cf(cf_handle, IteratorMode::From(&start_key, Direction::Forward));
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let start_key = delta_key(serial.wrapping_add(1));
let iter = self.db.iterator_cf(
cf_handle,
IteratorMode::From(&start_key, Direction::Forward),
);
let mut out = Vec::new();
for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref())
.ok_or_else(|| anyhow!("Invalid delta key"))?;
for (key, value) in iter {
let parsed = delta_key_serial(&key).ok_or_else(|| anyhow!("Invalid delta key"))?;
if parsed <= serial {
continue;
}
let delta: Delta = serde_json::from_slice(&value)?;
let delta: Delta = serde_json::from_slice(value.as_ref())?;
out.push(delta);
}

1
tests/common/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod test_helper;

332
tests/common/test_helper.rs Normal file
View File

@ -0,0 +1,332 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use std::fmt::Write;
use serde_json::{json, Value};
use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use rpki::rtr::payload::{Payload, RouteOrigin};
use rpki::rtr::pdu::{CacheResponse, EndOfDataV1, IPv4Prefix, IPv6Prefix};
use rpki::rtr::cache::SerialResult;
/// Collects PDUs as JSON values during a test and pretty-prints the whole
/// recorded stream at the end, so the exchanged PDUs can be inspected in
/// test output.
pub struct RtrDebugDumper {
    // Recorded entries, each {"pdu", "pdu_name", "body"}.
    entries: Vec<Value>,
}
impl RtrDebugDumper {
    /// Creates an empty dumper.
    pub fn new() -> Self {
        Self { entries: Vec::new() }
    }
    /// Records a serializable PDU body under its numeric PDU type.
    ///
    /// Panics if `body` fails to serialize (acceptable in a test helper).
    pub fn push<T: serde::Serialize>(&mut self, pdu: u8, body: &T) {
        // Delegate to push_value so the entry shape is defined in exactly
        // one place instead of being duplicated here.
        self.push_value(pdu, json!(body));
    }
    /// Records an already-built JSON body under its numeric PDU type.
    pub fn push_value(&mut self, pdu: u8, body: Value) {
        self.entries.push(json!({
            "pdu": pdu,
            "pdu_name": pdu_type_name(pdu),
            "body": body
        }));
    }
    /// Prints every recorded entry as pretty JSON, framed with the test
    /// name for easy grepping in test logs.
    pub fn print_pretty(&self, test_name: &str) {
        println!(
            "\n===== RTR Debug Dump: {} =====\n{}\n",
            test_name,
            serde_json::to_string_pretty(&self.entries).unwrap()
        );
    }
}
/// Maps an RTR PDU type code to its human-readable name; unassigned codes
/// yield "Unknown".
pub fn pdu_type_name(pdu: u8) -> &'static str {
    const NAMES: &[(u8, &str)] = &[
        (0, "Serial Notify"),
        (1, "Serial Query"),
        (2, "Reset Query"),
        (3, "Cache Response"),
        (4, "IPv4 Prefix"),
        (6, "IPv6 Prefix"),
        (7, "End of Data"),
        (8, "Cache Reset"),
        (9, "Router Key"),
        (10, "Error Report"),
        (11, "ASPA"),
        (255, "Reserved"),
    ];
    NAMES
        .iter()
        .find(|&&(code, _)| code == pdu)
        .map(|&(_, name)| name)
        .unwrap_or("Unknown")
}
/// Serializes a Cache Response PDU header into a JSON debug value.
pub fn dump_cache_response(resp: &CacheResponse) -> Value {
    let header = json!({
        "version": resp.version(),
        "pdu": resp.pdu(),
        "session_id": resp.session_id(),
        "length": 8
    });
    json!({ "header": header })
}
/// Serializes an IPv4 Prefix PDU into a JSON debug value.
pub fn dump_ipv4_prefix(p: &IPv4Prefix) -> Value {
    let announce = p.flag().is_announce();
    let raw_flag = if announce { 1 } else { 0 };
    json!({
        "header": {
            "version": p.version(),
            "pdu": p.pdu(),
            "session_id": 0,
            "length": 20
        },
        "flags": {
            "raw": raw_flag,
            "announce": announce
        },
        "prefix": p.prefix().to_string(),
        "prefix_len": p.prefix_len(),
        "max_len": p.max_len(),
        "asn": p.asn().into_u32()
    })
}
/// Serializes an IPv6 Prefix PDU into a JSON debug value.
pub fn dump_ipv6_prefix(p: &IPv6Prefix) -> Value {
    let announce = p.flag().is_announce();
    let raw_flag = if announce { 1 } else { 0 };
    json!({
        "header": {
            "version": p.version(),
            "pdu": p.pdu(),
            "session_id": 0,
            "length": 32
        },
        "flags": {
            "raw": raw_flag,
            "announce": announce
        },
        "prefix": p.prefix().to_string(),
        "prefix_len": p.prefix_len(),
        "max_len": p.max_len(),
        "asn": p.asn().into_u32()
    })
}
/// Serializes a version-1 End of Data PDU (including its timing values)
/// into a JSON debug value.
pub fn dump_eod_v1(eod: &EndOfDataV1) -> Value {
    let timing = eod.timing();
    let header = json!({
        "version": eod.version(),
        "pdu": eod.pdu(),
        "session_id": eod.session_id(),
        "length": 24
    });
    json!({
        "header": header,
        "serial_number": eod.serial_number(),
        "refresh_interval": timing.refresh,
        "retry_interval": timing.retry,
        "expire_interval": timing.expire
    })
}
/// Serializes a Cache Reset PDU header into a JSON debug value.
pub fn dump_cache_reset(version: u8, pdu: u8) -> Value {
    let header = json!({
        "version": version,
        "pdu": pdu,
        "session_id": 0,
        "length": 8
    });
    json!({ "header": header })
}
/// Builds an IPv4 `IPAddressPrefix` from dotted-quad octets and a length.
pub fn v4_prefix(a: u8, b: u8, c: u8, d: u8, prefix_len: u8) -> IPAddressPrefix {
    let address = IPAddress::from_ipv4(Ipv4Addr::new(a, b, c, d));
    IPAddressPrefix {
        address,
        prefix_length: prefix_len,
    }
}
/// Builds an IPv6 `IPAddressPrefix` from an address and a length.
pub fn v6_prefix(addr: Ipv6Addr, prefix_len: u8) -> IPAddressPrefix {
    let address = IPAddress::from_ipv6(addr);
    IPAddressPrefix {
        address,
        prefix_length: prefix_len,
    }
}
/// Builds an IPv4 `RouteOrigin` test payload from octets, prefix length,
/// max length and ASN.
pub fn v4_origin(
    a: u8,
    b: u8,
    c: u8,
    d: u8,
    prefix_len: u8,
    max_len: u8,
    asn: u32,
) -> RouteOrigin {
    RouteOrigin::new(v4_prefix(a, b, c, d, prefix_len), max_len, asn.into())
}
/// Builds an IPv6 `RouteOrigin` test payload from address, prefix length,
/// max length and ASN.
pub fn v6_origin(
    addr: Ipv6Addr,
    prefix_len: u8,
    max_len: u8,
    asn: u32,
) -> RouteOrigin {
    RouteOrigin::new(v6_prefix(addr, prefix_len), max_len, asn.into())
}
/// Extracts the `RouteOrigin` out of a `Payload`; panics on any other
/// variant (test-helper behavior).
pub fn as_route_origin(payload: &Payload) -> &RouteOrigin {
    if let Payload::RouteOrigin(ro) = payload {
        ro
    } else {
        panic!("expected RouteOrigin payload")
    }
}
/// Like `as_route_origin`, but additionally asserts the prefix is IPv4.
pub fn as_v4_route_origin(payload: &Payload) -> &RouteOrigin {
    let origin = as_route_origin(payload);
    assert!(
        origin.prefix().address.is_ipv4(),
        "expected IPv4 RouteOrigin"
    );
    origin
}
/// Like `as_route_origin`, but additionally asserts the prefix is IPv6.
pub fn as_v6_route_origin(payload: &Payload) -> &RouteOrigin {
    let origin = as_route_origin(payload);
    assert!(
        origin.prefix().address.is_ipv6(),
        "expected IPv6 RouteOrigin"
    );
    origin
}
/// Renders a `RouteOrigin` as "addr/prefix_len-max_len ASn" for reports.
pub fn route_origin_to_string(ro: &RouteOrigin) -> String {
    let prefix = ro.prefix();
    let rendered_addr = match prefix.address {
        IPAddress::V4(a) => a.to_string(),
        IPAddress::V6(a) => a.to_string(),
    };
    format!(
        "{}/{}-{} AS{}",
        rendered_addr,
        prefix.prefix_length,
        ro.max_length(),
        ro.asn().into_u32()
    )
}
/// Renders any `Payload` variant as a short, single-line description;
/// only route origins are expanded in detail.
pub fn payload_to_string(payload: &Payload) -> String {
    match payload {
        Payload::RouterKey(_) => String::from("RouterKey(...)"),
        Payload::Aspa(_) => String::from("Aspa(...)"),
        Payload::RouteOrigin(ro) => format!("RouteOrigin({})", route_origin_to_string(ro)),
    }
}
/// Renders payloads one per line, each prefixed with its index.
pub fn payloads_to_pretty_lines(payloads: &[Payload]) -> String {
    payloads
        .iter()
        .enumerate()
        .map(|(idx, payload)| format!(" [{}] {}\n", idx, payload_to_string(payload)))
        .collect()
}
/// Prints a labeled, banner-framed payload listing to stdout.
pub fn print_payloads(label: &str, payloads: &[Payload]) {
    let body = payloads_to_pretty_lines(payloads);
    println!("\n===== {} =====\n{}", label, body);
}
/// Renders a `SerialResult` compactly; the Deltas variant lists only the
/// contained delta serial numbers.
pub fn serial_result_to_string(result: &SerialResult) -> String {
    match result {
        SerialResult::UpToDate => String::from("UpToDate"),
        SerialResult::ResetRequired => String::from("ResetRequired"),
        SerialResult::Deltas(deltas) => {
            let serials = deltas.iter().map(|d| d.serial()).collect::<Vec<u32>>();
            format!("Deltas {:?}", serials)
        }
    }
}
/// Prints a labeled, banner-framed compact `SerialResult` to stdout.
pub fn print_serial_result(label: &str, result: &SerialResult) {
    println!("\n===== {} =====\n{}\n", label, serial_result_to_string(result));
}
/// Renders a byte slice as lowercase hex, two digits per byte.
pub fn bytes_to_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| {
            let _ = write!(&mut acc, "{:02x}", byte);
            acc
        })
}
/// Prints the four content hashes of a cache snapshot (origins, router
/// keys, ASPAs, combined) as hex, framed with a label banner.
pub fn print_snapshot_hashes(label: &str, snapshot: &rpki::rtr::cache::Snapshot) {
    println!(
        "\n===== {} =====\norigins_hash={}\nrouter_keys_hash={}\naspas_hash={}\nsnapshot_hash={}\n",
        label,
        bytes_to_hex(&snapshot.origins_hash()),
        bytes_to_hex(&snapshot.router_keys_hash()),
        bytes_to_hex(&snapshot.aspas_hash()),
        bytes_to_hex(&snapshot.snapshot_hash()),
    );
}
/// Prints a structured test report (name, purpose, input, output) to
/// stdout. The section labels in the template are intentionally in
/// Chinese; they are part of the emitted output format.
pub fn test_report(
    name: &str,
    purpose: &str,
    input: &str,
    output: &str,
) {
    println!(
        "\n==================== TEST REPORT ====================\n测试名称: {}\n测试目的: {}\n\n【输入】\n{}\n【输出】\n{}\n====================================================\n",
        name, purpose, input, output
    );
}
/// Renders payloads one per line with indices; an empty slice is shown as
/// a placeholder "<empty>" line instead of an empty string.
pub fn payloads_to_string(payloads: &[Payload]) -> String {
    let rendered: String = payloads
        .iter()
        .enumerate()
        .map(|(idx, payload)| format!(" [{}] {}\n", idx, payload_to_string(payload)))
        .collect();
    if rendered.is_empty() {
        " <empty>\n".to_string()
    } else {
        rendered
    }
}
/// Renders a snapshot's four content hashes as an indented multi-line
/// block, for embedding into test reports.
pub fn snapshot_hashes_to_string(snapshot: &rpki::rtr::cache::Snapshot) -> String {
    format!(
        " origins_hash: {}\n router_keys_hash: {}\n aspas_hash: {}\n snapshot_hash: {}\n",
        bytes_to_hex(&snapshot.origins_hash()),
        bytes_to_hex(&snapshot.router_keys_hash()),
        bytes_to_hex(&snapshot.aspas_hash()),
        bytes_to_hex(&snapshot.snapshot_hash()),
    )
}
/// Detailed multi-line rendering of a `SerialResult` for test reports:
/// UpToDate/ResetRequired are single lines, Deltas is expanded per delta
/// with its announced/withdrawn payload lists indented underneath.
pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -> String {
    match result {
        rpki::rtr::cache::SerialResult::UpToDate => {
            " result: UpToDate\n".to_string()
        }
        rpki::rtr::cache::SerialResult::ResetRequired => {
            " result: ResetRequired\n".to_string()
        }
        rpki::rtr::cache::SerialResult::Deltas(deltas) => {
            let mut out = String::new();
            let _ = writeln!(&mut out, " result: Deltas");
            for (idx, delta) in deltas.iter().enumerate() {
                let _ = writeln!(&mut out, " delta[{}].serial: {}", idx, delta.serial());
                let _ = writeln!(&mut out, " delta[{}].announced:", idx);
                // Nest the payload listing a further 4 spaces under its header.
                out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4));
                let _ = writeln!(&mut out, " delta[{}].withdrawn:", idx);
                out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4));
            }
            out
        }
    }
}
/// Prefix every line of `text` with `spaces` spaces; each emitted line is
/// newline-terminated. Empty input yields an empty string.
pub fn indent_block(text: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    text.lines()
        .map(|line| format!("{}{}\n", pad, line))
        .collect()
}

743
tests/test_cache.rs Normal file
View File

@ -0,0 +1,743 @@
mod common;
use std::collections::VecDeque;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use common::test_helper::{
as_route_origin, as_v4_route_origin, indent_block, payloads_to_string,
serial_result_detail_to_string, snapshot_hashes_to_string, test_report, v4_origin, v6_origin,
};
use rpki::rtr::cache::{Delta, RtrCacheBuilder, SerialResult, Snapshot};
use rpki::rtr::payload::{Payload, Timing};
use rpki::rtr::store_db::RtrStore;
/// Render a delta as its serial followed by indented announced/withdrawn lists.
fn delta_to_string(delta: &Delta) -> String {
    let announced = indent_block(&payloads_to_string(delta.announced()), 2);
    let withdrawn = indent_block(&payloads_to_string(delta.withdrawn()), 2);
    format!(
        "serial: {}\nannounced:\n{}withdrawn:\n{}",
        delta.serial(),
        announced,
        withdrawn,
    )
}
/// Render the cache's delta window, one "delta[i]:" section per entry,
/// or the placeholder " <empty>\n" for an empty window.
fn deltas_window_to_string(deltas: &VecDeque<Arc<Delta>>) -> String {
    if deltas.is_empty() {
        return " <empty>\n".to_string();
    }
    deltas
        .iter()
        .enumerate()
        .map(|(idx, delta)| {
            format!("delta[{}]:\n{}", idx, indent_block(&delta_to_string(delta), 2))
        })
        .collect()
}
/// Render the inputs of a `get_deltas_since` call as labelled lines,
/// pairing the cache's identity with the client's claimed state.
fn get_deltas_since_input_to_string(
    cache_session_id: u16,
    cache_serial: u32,
    client_session: u16,
    client_serial: u32,
) -> String {
    let mut out = String::new();
    out.push_str(&format!("cache.session_id: {}\n", cache_session_id));
    out.push_str(&format!("cache.serial: {}\n", cache_serial));
    out.push_str(&format!("client_session: {}\n", client_session));
    out.push_str(&format!("client_serial: {}\n", client_serial));
    out
}
/// Render a snapshot's hashes together with its RTR-ordered payload listing.
fn snapshot_hashes_and_sorted_view_to_string(snapshot: &Snapshot) -> String {
    let sorted = snapshot.payloads_for_rtr();
    let hashes = indent_block(&snapshot_hashes_to_string(snapshot), 2);
    let listing = indent_block(&payloads_to_string(&sorted), 2);
    format!("hashes:\n{}sorted payloads_for_rtr:\n{}", hashes, listing)
}
// Building the same logical payload set in two different input orders must
// yield snapshots that compare equal and produce identical hashes.
#[test]
fn snapshot_hash_is_stable_for_same_content_with_different_input_order() {
    // Two IPv4 origins and one IPv6 origin from documentation prefix ranges.
    let a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let c = v6_origin(
        Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0),
        32,
        48,
        64498,
    );
    let s1_input = vec![
        Payload::RouteOrigin(a.clone()),
        Payload::RouteOrigin(b.clone()),
        Payload::RouteOrigin(c.clone()),
    ];
    // Same three payloads, permuted.
    let s2_input = vec![
        Payload::RouteOrigin(c),
        Payload::RouteOrigin(a),
        Payload::RouteOrigin(b),
    ];
    let s1 = Snapshot::from_payloads(s1_input.clone());
    let s2 = Snapshot::from_payloads(s2_input.clone());
    let input = format!(
        "s1 原始输入 payloads:\n{}\ns2 原始输入 payloads:\n{}",
        indent_block(&payloads_to_string(&s1_input), 2),
        indent_block(&payloads_to_string(&s2_input), 2),
    );
    let output = format!(
        "s1:\n{}\ns2:\n{}\n结论:\n same_content: {}\n same_origins: {}\n snapshot_hash 相同: {}\n origins_hash 相同: {}\n",
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&s1), 2),
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&s2), 2),
        s1.same_content(&s2),
        s1.same_origins(&s2),
        s1.snapshot_hash() == s2.snapshot_hash(),
        s1.origins_hash() == s2.origins_hash(),
    );
    test_report(
        "snapshot_hash_is_stable_for_same_content_with_different_input_order",
        "验证相同语义内容即使原始输入顺序不同Snapshot 的 hash 仍然稳定一致。",
        &input,
        &output,
    );
    // Equality and hashing must be insensitive to the original input order.
    assert!(s1.same_content(&s2));
    assert!(s1.same_origins(&s2));
    assert_eq!(s1.snapshot_hash(), s2.snapshot_hash());
    assert_eq!(s1.origins_hash(), s2.origins_hash());
}
// Diffing two snapshots must report exactly what was added (announced) and
// what disappeared (withdrawn); the shared payload appears in neither list.
#[test]
fn snapshot_diff_reports_announced_and_withdrawn_correctly() {
    let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let new_c = v6_origin(
        Ipv6Addr::new(0x2001, 0xdb8, 0, 1, 0, 0, 0, 0),
        48,
        48,
        64499,
    );
    // Old set: {a, b}; new set: {b, c} -> expect announced {c}, withdrawn {a}.
    let old_input = vec![
        Payload::RouteOrigin(old_a.clone()),
        Payload::RouteOrigin(old_b.clone()),
    ];
    let new_input = vec![
        Payload::RouteOrigin(old_b),
        Payload::RouteOrigin(new_c.clone()),
    ];
    let old_snapshot = Snapshot::from_payloads(old_input.clone());
    let new_snapshot = Snapshot::from_payloads(new_input.clone());
    let (announced, withdrawn) = old_snapshot.diff(&new_snapshot);
    let input = format!(
        "old_snapshot 原始输入:\n{}\nnew_snapshot 原始输入:\n{}",
        indent_block(&payloads_to_string(&old_input), 2),
        indent_block(&payloads_to_string(&new_input), 2),
    );
    let output = format!(
        "announced:\n{}withdrawn:\n{}",
        indent_block(&payloads_to_string(&announced), 2),
        indent_block(&payloads_to_string(&withdrawn), 2),
    );
    test_report(
        "snapshot_diff_reports_announced_and_withdrawn_correctly",
        "验证 diff() 能正确找出 announced 和 withdrawn 的 payload。",
        &input,
        &output,
    );
    assert_eq!(announced.len(), 1);
    assert_eq!(withdrawn.len(), 1);
    match &announced[0] {
        Payload::RouteOrigin(ro) => assert_eq!(ro, &new_c),
        _ => panic!("expected announced RouteOrigin"),
    }
    match &withdrawn[0] {
        Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a),
        _ => panic!("expected withdrawn RouteOrigin"),
    }
}
// payloads_for_rtr() must emit IPv4 origins before IPv6 origins, with the
// IPv4 entries ordered by descending address.
#[test]
fn snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending() {
    let v4_low = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let v4_high = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let v6 = v6_origin(
        Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0),
        32,
        48,
        64498,
    );
    // Deliberately feed the snapshot in "wrong" order (v6 first, low v4 next).
    let input_payloads = vec![
        Payload::RouteOrigin(v6.clone()),
        Payload::RouteOrigin(v4_low.clone()),
        Payload::RouteOrigin(v4_high.clone()),
    ];
    let snapshot = Snapshot::from_payloads(input_payloads.clone());
    let output_payloads = snapshot.payloads_for_rtr();
    let input = format!(
        "原始输入 payloads构造 Snapshot 前):\n{}",
        indent_block(&payloads_to_string(&input_payloads), 2),
    );
    let output = format!(
        "排序后 payloads_for_rtr:\n{}",
        indent_block(&payloads_to_string(&output_payloads), 2),
    );
    test_report(
        "snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending",
        "验证 Snapshot::payloads_for_rtr() 会按 RTR 规则排序IPv4 在 IPv6 前,且 IPv4 announcement 按地址降序。",
        &input,
        &output,
    );
    assert_eq!(output_payloads.len(), 3);
    // Expected order: high v4, low v4, then the v6 origin.
    let first = as_v4_route_origin(&output_payloads[0]);
    let second = as_v4_route_origin(&output_payloads[1]);
    assert_eq!(
        first.prefix().address.to_ipv4(),
        Some(Ipv4Addr::new(198, 51, 100, 0))
    );
    assert_eq!(
        second.prefix().address.to_ipv4(),
        Some(Ipv4Addr::new(192, 0, 2, 0))
    );
    let third = as_route_origin(&output_payloads[2]);
    assert!(third.prefix().address.is_ipv6());
    assert_eq!(
        third.prefix().address.to_ipv6(),
        Some(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0))
    );
}
// Delta::new() must normalize ordering itself: announcements sorted by
// descending address, withdrawals by ascending address.
#[test]
fn delta_new_sorts_announced_descending_and_withdrawn_ascending() {
    let announced_low = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let announced_high = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let withdrawn_high = v4_origin(203, 0, 113, 0, 24, 24, 64501);
    let withdrawn_low = v4_origin(10, 0, 0, 0, 24, 24, 64500);
    // Both input lists are intentionally in the "wrong" order for their rule.
    let input_announced = vec![
        Payload::RouteOrigin(announced_low),
        Payload::RouteOrigin(announced_high),
    ];
    let input_withdrawn = vec![
        Payload::RouteOrigin(withdrawn_high),
        Payload::RouteOrigin(withdrawn_low),
    ];
    let delta = Delta::new(101, input_announced.clone(), input_withdrawn.clone());
    let input = format!(
        "announced构造前:\n{}withdrawn构造前:\n{}",
        indent_block(&payloads_to_string(&input_announced), 2),
        indent_block(&payloads_to_string(&input_withdrawn), 2),
    );
    let output = indent_block(&delta_to_string(&delta), 2);
    test_report(
        "delta_new_sorts_announced_descending_and_withdrawn_ascending",
        "验证 Delta::new() 会自动排序announced 按 RTR announcement 规则withdrawn 按 RTR withdrawal 规则。",
        &input,
        &output,
    );
    assert_eq!(delta.serial(), 101);
    assert_eq!(delta.announced().len(), 2);
    assert_eq!(delta.withdrawn().len(), 2);
    // Announcements: high address first (descending).
    let a0 = as_v4_route_origin(&delta.announced()[0]);
    let a1 = as_v4_route_origin(&delta.announced()[1]);
    assert_eq!(
        a0.prefix().address.to_ipv4(),
        Some(Ipv4Addr::new(198, 51, 100, 0))
    );
    assert_eq!(
        a1.prefix().address.to_ipv4(),
        Some(Ipv4Addr::new(192, 0, 2, 0))
    );
    // Withdrawals: low address first (ascending).
    let w0 = as_v4_route_origin(&delta.withdrawn()[0]);
    let w1 = as_v4_route_origin(&delta.withdrawn()[1]);
    assert_eq!(w0.prefix().address.to_ipv4(), Some(Ipv4Addr::new(10, 0, 0, 0)));
    assert_eq!(
        w1.prefix().address.to_ipv4(),
        Some(Ipv4Addr::new(203, 0, 113, 0))
    );
}
// A client whose session and serial both match the cache gets UpToDate.
#[test]
fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() {
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .build();
    // Same session id (42) and same serial (100) as the cache.
    let result = cache.get_deltas_since(42, 100);
    let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 100);
    let output = serial_result_detail_to_string(&result);
    test_report(
        "get_deltas_since_returns_up_to_date_when_client_serial_matches_current",
        "验证当客户端 serial 与缓存当前 serial 相同,返回 UpToDate。",
        &input,
        &output,
    );
    match result {
        SerialResult::UpToDate => {}
        _ => panic!("expected UpToDate"),
    }
}
// A session-id mismatch invalidates the client's serial entirely, so the
// cache must demand a full reset.
#[test]
fn get_deltas_since_returns_reset_required_on_session_mismatch() {
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .build();
    // Client claims session 999 while the cache runs session 42.
    let result = cache.get_deltas_since(999, 100);
    let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 999, 100);
    let output = serial_result_detail_to_string(&result);
    test_report(
        "get_deltas_since_returns_reset_required_on_session_mismatch",
        "验证当客户端 session_id 与缓存 session_id 不一致时,返回 ResetRequired。",
        &input,
        &output,
    );
    match result {
        SerialResult::ResetRequired => {}
        _ => panic!("expected ResetRequired"),
    }
}
// When the client's serial predates the oldest retained delta, the window
// cannot bridge the gap and the cache must demand a reset.
#[test]
fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() {
    let d1 = Arc::new(Delta::new(
        101,
        vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
        vec![],
    ));
    let d2 = Arc::new(Delta::new(
        102,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![],
    ));
    let mut deltas = VecDeque::new();
    deltas.push_back(d1);
    deltas.push_back(d2);
    // Window covers serials 101..=102; the client is at 99, before it.
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(102)
        .timing(Timing::default())
        .deltas(deltas.clone())
        .build();
    let result = cache.get_deltas_since(42, 99);
    let input = format!(
        "{}delta_window:\n{}",
        get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 99),
        indent_block(&deltas_window_to_string(&deltas), 2),
    );
    let output = serial_result_detail_to_string(&result);
    test_report(
        "get_deltas_since_returns_reset_required_when_client_serial_is_too_old",
        "验证当客户端 serial 太旧,已超出 delta window 覆盖范围时,返回 ResetRequired。",
        &input,
        &output,
    );
    match result {
        SerialResult::ResetRequired => {}
        _ => panic!("expected ResetRequired"),
    }
}
// A client inside the delta window gets exactly the deltas newer than its
// serial, in ascending serial order.
#[test]
fn get_deltas_since_returns_applicable_deltas_in_order() {
    let d1 = Arc::new(Delta::new(
        101,
        vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
        vec![],
    ));
    let d2 = Arc::new(Delta::new(
        102,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![],
    ));
    let d3 = Arc::new(Delta::new(
        103,
        vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))],
        vec![],
    ));
    let mut deltas = VecDeque::new();
    deltas.push_back(d1);
    deltas.push_back(d2);
    deltas.push_back(d3);
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(103)
        .timing(Timing::default())
        .deltas(deltas.clone())
        .build();
    // Client is at 101 -> expects deltas 102 and 103 (not 101 itself).
    let result = cache.get_deltas_since(42, 101);
    let input = format!(
        "{}delta_window:\n{}",
        get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101),
        indent_block(&deltas_window_to_string(&deltas), 2),
    );
    let output = serial_result_detail_to_string(&result);
    test_report(
        "get_deltas_since_returns_applicable_deltas_in_order",
        "验证当客户端 serial 在 delta window 内时,返回正确且有序的 deltas。",
        &input,
        &output,
    );
    match result {
        SerialResult::Deltas(result) => {
            assert_eq!(result.len(), 2);
            assert_eq!(result[0].serial(), 102);
            assert_eq!(result[1].serial(), 103);
        }
        _ => panic!("expected Deltas"),
    }
}
// A client serial ahead of the cache is inconsistent state; the cache must
// demand a reset rather than guess.
#[test]
fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() {
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .build();
    // Client claims serial 101 while the cache is only at 100.
    let result = cache.get_deltas_since(42, 101);
    let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101);
    let output = serial_result_detail_to_string(&result);
    test_report(
        "get_deltas_since_returns_reset_required_when_client_serial_is_in_future",
        "验证当客户端 serial 比缓存当前 serial 还大时,返回 ResetRequired。",
        &input,
        &output,
    );
    match result {
        SerialResult::ResetRequired => {}
        _ => panic!("expected ResetRequired"),
    }
}
// Feeding update() the same content (in a different order) must be a no-op:
// serial unchanged, snapshot unchanged, and a subsequent serial query UpToDate.
#[tokio::test]
async fn update_no_change_keeps_serial_and_produces_no_delta() {
    let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let old_input = vec![
        Payload::RouteOrigin(old_a.clone()),
        Payload::RouteOrigin(old_b.clone()),
    ];
    let snapshot = Snapshot::from_payloads(old_input.clone());
    let mut cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .snapshot(snapshot.clone())
        .build();
    // Persist into a throwaway store backed by a temp directory.
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    // Same two payloads, reversed order — semantically identical content.
    let new_payloads = vec![
        Payload::RouteOrigin(old_b),
        Payload::RouteOrigin(old_a),
    ];
    cache.update(new_payloads.clone(), &store).unwrap();
    let current_snapshot = cache.snapshot();
    let result = cache.get_deltas_since(42, 100);
    let input = format!(
        "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}",
        indent_block(&payloads_to_string(&old_input), 2),
        indent_block(&payloads_to_string(&new_payloads), 2),
    );
    let output = format!(
        "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}",
        cache.serial(),
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2),
        indent_block(&serial_result_detail_to_string(&result), 2),
    );
    test_report(
        "update_no_change_keeps_serial_and_produces_no_delta",
        "验证 update() 在新旧内容完全相同时serial 不变、snapshot 不变、不会产生新的 delta。",
        &input,
        &output,
    );
    assert_eq!(cache.serial(), 100);
    assert!(cache.snapshot().same_content(&snapshot));
    match result {
        SerialResult::UpToDate => {}
        _ => panic!("expected UpToDate"),
    }
}
// Adding one payload via update() must bump the serial by one and yield a
// single delta containing only that announcement (no withdrawals).
#[tokio::test]
async fn update_add_only_increments_serial_and_generates_announced_delta() {
    let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let new_b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let old_input = vec![Payload::RouteOrigin(old_a.clone())];
    let old_snapshot = Snapshot::from_payloads(old_input.clone());
    let mut cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .snapshot(old_snapshot.clone())
        .build();
    // Persist into a throwaway store backed by a temp directory.
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    // New content = old content plus one extra origin.
    let new_payloads = vec![
        Payload::RouteOrigin(old_a.clone()),
        Payload::RouteOrigin(new_b.clone()),
    ];
    cache.update(new_payloads.clone(), &store).unwrap();
    let current_snapshot = cache.snapshot();
    let result = cache.get_deltas_since(42, 100);
    let input = format!(
        "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}",
        indent_block(&payloads_to_string(&old_input), 2),
        indent_block(&payloads_to_string(&new_payloads), 2),
    );
    let output = format!(
        "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}",
        cache.serial(),
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2),
        indent_block(&serial_result_detail_to_string(&result), 2),
    );
    test_report(
        "update_add_only_increments_serial_and_generates_announced_delta",
        "验证 update() 在只新增 payload 时serial + 1delta 中只有 announcedwithdrawn 为空。",
        &input,
        &output,
    );
    assert_eq!(cache.serial(), 101);
    match result {
        SerialResult::Deltas(deltas) => {
            assert_eq!(deltas.len(), 1);
            let delta = &deltas[0];
            assert_eq!(delta.serial(), 101);
            assert_eq!(delta.announced().len(), 1);
            assert_eq!(delta.withdrawn().len(), 0);
            match &delta.announced()[0] {
                Payload::RouteOrigin(ro) => assert_eq!(ro, &new_b),
                _ => panic!("expected announced RouteOrigin"),
            }
        }
        _ => panic!("expected Deltas"),
    }
}
// Removing one payload via update() must bump the serial by one and yield a
// single delta containing only that withdrawal (no announcements).
#[tokio::test]
async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() {
    let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let old_input = vec![
        Payload::RouteOrigin(old_a.clone()),
        Payload::RouteOrigin(old_b.clone()),
    ];
    let old_snapshot = Snapshot::from_payloads(old_input.clone());
    let mut cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .snapshot(old_snapshot.clone())
        .build();
    // Persist into a throwaway store backed by a temp directory.
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    // New content drops old_a and keeps only old_b.
    let new_payloads = vec![Payload::RouteOrigin(old_b.clone())];
    cache.update(new_payloads.clone(), &store).unwrap();
    let current_snapshot = cache.snapshot();
    let result = cache.get_deltas_since(42, 100);
    let input = format!(
        "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}",
        indent_block(&payloads_to_string(&old_input), 2),
        indent_block(&payloads_to_string(&new_payloads), 2),
    );
    let output = format!(
        "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}",
        cache.serial(),
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2),
        indent_block(&serial_result_detail_to_string(&result), 2),
    );
    test_report(
        "update_remove_only_increments_serial_and_generates_withdrawn_delta",
        "验证 update() 在只删除 payload 时serial + 1delta 中只有 withdrawnannounced 为空。",
        &input,
        &output,
    );
    assert_eq!(cache.serial(), 101);
    match result {
        SerialResult::Deltas(deltas) => {
            assert_eq!(deltas.len(), 1);
            let delta = &deltas[0];
            assert_eq!(delta.serial(), 101);
            assert_eq!(delta.announced().len(), 0);
            assert_eq!(delta.withdrawn().len(), 1);
            match &delta.withdrawn()[0] {
                Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a),
                _ => panic!("expected withdrawn RouteOrigin"),
            }
        }
        _ => panic!("expected Deltas"),
    }
}
// A mixed update (one payload added, one removed) must bump the serial by
// one and produce a single delta carrying both the announcement and the
// withdrawal.
#[tokio::test]
async fn update_add_and_remove_increments_serial_and_generates_both_sides() {
    let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496);
    let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497);
    let new_c = v6_origin(
        Ipv6Addr::new(0x2001, 0xdb8, 0, 1, 0, 0, 0, 0),
        48,
        48,
        64499,
    );
    let old_input = vec![
        Payload::RouteOrigin(old_a.clone()),
        Payload::RouteOrigin(old_b.clone()),
    ];
    let old_snapshot = Snapshot::from_payloads(old_input.clone());
    let mut cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::default())
        .snapshot(old_snapshot.clone())
        .build();
    // Persist into a throwaway store backed by a temp directory.
    let dir = tempfile::tempdir().unwrap();
    let store = RtrStore::open(dir.path()).unwrap();
    // New content: old_a removed, new_c added, old_b unchanged.
    let new_payloads = vec![
        Payload::RouteOrigin(old_b.clone()),
        Payload::RouteOrigin(new_c.clone()),
    ];
    cache.update(new_payloads.clone(), &store).unwrap();
    let current_snapshot = cache.snapshot();
    let result = cache.get_deltas_since(42, 100);
    let input = format!(
        "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}",
        indent_block(&payloads_to_string(&old_input), 2),
        indent_block(&payloads_to_string(&new_payloads), 2),
    );
    let output = format!(
        "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}",
        cache.serial(),
        indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2),
        indent_block(&serial_result_detail_to_string(&result), 2),
    );
    test_report(
        "update_add_and_remove_increments_serial_and_generates_both_sides",
        "验证 update() 在同时新增和删除 payload 时serial + 1delta 中 announced 和 withdrawn 都正确。",
        &input,
        &output,
    );
    assert_eq!(cache.serial(), 101);
    match result {
        SerialResult::Deltas(deltas) => {
            assert_eq!(deltas.len(), 1);
            let delta = &deltas[0];
            assert_eq!(delta.serial(), 101);
            assert_eq!(delta.announced().len(), 1);
            assert_eq!(delta.withdrawn().len(), 1);
            match &delta.announced()[0] {
                Payload::RouteOrigin(ro) => assert_eq!(ro, &new_c),
                _ => panic!("expected announced RouteOrigin"),
            }
            match &delta.withdrawn()[0] {
                Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a),
                _ => panic!("expected withdrawn RouteOrigin"),
            }
        }
        _ => panic!("expected Deltas"),
    }
}

492
tests/test_session.rs Normal file
View File

@ -0,0 +1,492 @@
mod common;
use std::collections::VecDeque;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::{Arc, RwLock};
use tokio::net::TcpListener;
use tokio::sync::{broadcast, watch};
use common::test_helper::{
dump_cache_reset, dump_cache_response, dump_eod_v1, dump_ipv4_prefix, dump_ipv6_prefix,
RtrDebugDumper,
};
use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use rpki::rtr::cache::{Delta, SharedRtrCache, RtrCacheBuilder, Snapshot};
use rpki::rtr::payload::{Payload, RouteOrigin, Timing};
use rpki::rtr::pdu::{
CacheResponse, CacheReset, EndOfDataV1, IPv4Prefix, IPv6Prefix, ResetQuery, SerialQuery,
};
use rpki::rtr::session::RtrSession;
/// Wrap a cache in the `Arc<RwLock<…>>` shape the session code shares.
fn shared_cache(cache: rpki::rtr::cache::RtrCache) -> SharedRtrCache {
    let guarded = RwLock::new(cache);
    Arc::new(guarded)
}
// End-to-end over TCP: a Reset Query must be answered with Cache Response,
// the snapshot's prefix PDU, and End Of Data carrying serial and timing.
#[tokio::test]
async fn reset_query_returns_snapshot_and_end_of_data() {
    let prefix = IPAddressPrefix {
        address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)),
        prefix_length: 24,
    };
    let origin = RouteOrigin::new(prefix, 24, 64496u32.into());
    let snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(origin)]);
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::new(600, 600, 7200))
        .snapshot(snapshot)
        .build();
    // Bind an ephemeral port so tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_cache = shared_cache(cache);
    let (_notify_tx, notify_rx) = broadcast::channel(16);
    let (_shutdown_tx, shutdown_rx) = watch::channel(false);
    // Server side: accept one connection and drive a session over it.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
        session.run().await.unwrap();
    });
    let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
    ResetQuery::new(1).write(&mut client).await.unwrap();
    let mut dump = RtrDebugDumper::new();
    // Expected PDU sequence: Cache Response (3), IPv4 Prefix (4), EOD v1 (7).
    let response = CacheResponse::read(&mut client).await.unwrap();
    dump.push_value(response.pdu(), dump_cache_response(&response));
    assert_eq!(response.pdu(), 3);
    assert_eq!(response.version(), 1);
    assert_eq!(response.session_id(), 42);
    let prefix = IPv4Prefix::read(&mut client).await.unwrap();
    dump.push_value(prefix.pdu(), dump_ipv4_prefix(&prefix));
    assert_eq!(prefix.pdu(), 4);
    assert_eq!(prefix.version(), 1);
    assert!(prefix.flag().is_announce());
    assert_eq!(prefix.prefix_len(), 24);
    assert_eq!(prefix.max_len(), 24);
    assert_eq!(prefix.prefix(), Ipv4Addr::new(192, 0, 2, 0));
    assert_eq!(prefix.asn(), 64496u32.into());
    let eod = EndOfDataV1::read(&mut client).await.unwrap();
    dump.push_value(eod.pdu(), dump_eod_v1(&eod));
    assert_eq!(eod.pdu(), 7);
    assert_eq!(eod.version(), 1);
    assert_eq!(eod.session_id(), 42);
    assert_eq!(eod.serial_number(), 100);
    let timing = eod.timing();
    assert_eq!(timing.refresh, 600);
    assert_eq!(timing.retry, 600);
    assert_eq!(timing.expire, 7200);
    dump.print_pretty("reset_query_returns_snapshot_and_end_of_data");
}
// A Serial Query whose session/serial already match the cache must be
// answered with a bare End Of Data (no Cache Response, no payload PDUs).
#[tokio::test]
async fn serial_query_returns_end_of_data_when_up_to_date() {
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing {
            refresh: 600,
            retry: 600,
            expire: 7200,
        })
        .build();
    // Bind an ephemeral port so tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_cache = shared_cache(cache);
    let (_notify_tx, notify_rx) = broadcast::channel(16);
    let (_shutdown_tx, shutdown_rx) = watch::channel(false);
    // Server side: accept one connection and drive a session over it.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
        session.run().await.unwrap();
    });
    let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
    // Query with the cache's own session (42) and current serial (100).
    SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap();
    let mut dump = RtrDebugDumper::new();
    let eod = EndOfDataV1::read(&mut client).await.unwrap();
    dump.push_value(eod.pdu(), dump_eod_v1(&eod));
    assert_eq!(eod.pdu(), 7);
    assert_eq!(eod.version(), 1);
    assert_eq!(eod.session_id(), 42);
    assert_eq!(eod.serial_number(), 100);
    let timing = eod.timing();
    assert_eq!(timing.refresh, 600);
    assert_eq!(timing.retry, 600);
    assert_eq!(timing.expire, 7200);
    dump.print_pretty("serial_query_returns_end_of_data_when_up_to_date");
}
// A Serial Query with the wrong session id must be answered with a
// Cache Reset PDU (type 8), telling the router to start over.
#[tokio::test]
async fn serial_query_returns_cache_reset_when_session_id_mismatch() {
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing {
            refresh: 600,
            retry: 600,
            expire: 7200,
        })
        .build();
    // Bind an ephemeral port so tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_cache = shared_cache(cache);
    let (_notify_tx, notify_rx) = broadcast::channel(16);
    let (_shutdown_tx, shutdown_rx) = watch::channel(false);
    // Server side: accept one connection and drive a session over it.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
        session.run().await.unwrap();
    });
    let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
    // Client claims session 999 while the server runs session 42.
    SerialQuery::new(1, 999, 100).write(&mut client).await.unwrap();
    let mut dump = RtrDebugDumper::new();
    let reset = CacheReset::read(&mut client).await.unwrap();
    dump.push_value(reset.pdu(), dump_cache_reset(reset.version(), reset.pdu()));
    assert_eq!(reset.pdu(), 8);
    assert_eq!(reset.version(), 1);
    dump.print_pretty("serial_query_returns_cache_reset_when_session_id_mismatch");
}
// When the client's serial lags by one covered delta, a Serial Query must be
// answered incrementally: Cache Response, the delta's prefix, End Of Data
// carrying the new serial.
#[tokio::test]
async fn serial_query_returns_deltas_when_incremental_update_available() {
    let prefix = IPAddressPrefix {
        address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)),
        prefix_length: 24,
    };
    let origin = RouteOrigin::new(prefix, 24, 64496u32.into());
    // One delta announcing the origin at serial 101.
    let delta = Arc::new(Delta::new(
        101,
        vec![Payload::RouteOrigin(origin)],
        vec![],
    ));
    let mut deltas = VecDeque::new();
    deltas.push_back(delta);
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(101)
        .timing(Timing {
            refresh: 600,
            retry: 600,
            expire: 7200,
        })
        .deltas(deltas)
        .build();
    // Bind an ephemeral port so tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_cache = shared_cache(cache);
    let (_notify_tx, notify_rx) = broadcast::channel(16);
    let (_shutdown_tx, shutdown_rx) = watch::channel(false);
    // Server side: accept one connection and drive a session over it.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
        session.run().await.unwrap();
    });
    let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
    // Client is one serial behind (100 vs 101).
    SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap();
    let mut dump = RtrDebugDumper::new();
    let response = CacheResponse::read(&mut client).await.unwrap();
    dump.push_value(response.pdu(), dump_cache_response(&response));
    assert_eq!(response.pdu(), 3);
    assert_eq!(response.version(), 1);
    assert_eq!(response.session_id(), 42);
    let prefix = IPv4Prefix::read(&mut client).await.unwrap();
    dump.push_value(prefix.pdu(), dump_ipv4_prefix(&prefix));
    assert_eq!(prefix.pdu(), 4);
    assert_eq!(prefix.version(), 1);
    assert!(prefix.flag().is_announce());
    assert_eq!(prefix.prefix_len(), 24);
    assert_eq!(prefix.max_len(), 24);
    assert_eq!(prefix.prefix(), Ipv4Addr::new(192, 0, 2, 0));
    assert_eq!(prefix.asn(), 64496u32.into());
    let eod = EndOfDataV1::read(&mut client).await.unwrap();
    dump.push_value(eod.pdu(), dump_eod_v1(&eod));
    assert_eq!(eod.pdu(), 7);
    assert_eq!(eod.version(), 1);
    assert_eq!(eod.session_id(), 42);
    assert_eq!(eod.serial_number(), 101);
    let timing = eod.timing();
    assert_eq!(timing.refresh, 600);
    assert_eq!(timing.retry, 600);
    assert_eq!(timing.expire, 7200);
    dump.print_pretty("serial_query_returns_deltas_when_incremental_update_available");
}
// Over the wire, a Reset Query's payload PDUs must arrive in RTR order:
// IPv4 prefixes first (descending address), then IPv6, then End Of Data.
#[tokio::test]
async fn reset_query_returns_payloads_in_rtr_order() {
    let v4_low_prefix = IPAddressPrefix {
        address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)),
        prefix_length: 24,
    };
    let v4_low_origin = RouteOrigin::new(v4_low_prefix, 24, 64496u32.into());
    let v4_high_prefix = IPAddressPrefix {
        address: IPAddress::from_ipv4(Ipv4Addr::new(198, 51, 100, 0)),
        prefix_length: 24,
    };
    let v4_high_origin = RouteOrigin::new(v4_high_prefix, 24, 64497u32.into());
    let v6_prefix = IPAddressPrefix {
        address: IPAddress::from_ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)),
        prefix_length: 32,
    };
    let v6_origin = RouteOrigin::new(v6_prefix, 48, 64498u32.into());
    // Feed the snapshot deliberately out of order (v6 first, low v4 next).
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v6_origin),
        Payload::RouteOrigin(v4_low_origin),
        Payload::RouteOrigin(v4_high_origin),
    ]);
    let cache = RtrCacheBuilder::new()
        .session_id(42)
        .serial(100)
        .timing(Timing::new(600, 600, 7200))
        .snapshot(snapshot)
        .build();
    // Bind an ephemeral port so tests don't collide.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    let server_cache = shared_cache(cache);
    let (_notify_tx, notify_rx) = broadcast::channel(16);
    let (_shutdown_tx, shutdown_rx) = watch::channel(false);
    // Server side: accept one connection and drive a session over it.
    tokio::spawn(async move {
        let (stream, _) = listener.accept().await.unwrap();
        let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
        session.run().await.unwrap();
    });
    let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
    ResetQuery::new(1).write(&mut client).await.unwrap();
    let mut dump = RtrDebugDumper::new();
    let response = CacheResponse::read(&mut client).await.unwrap();
    dump.push_value(response.pdu(), dump_cache_response(&response));
    assert_eq!(response.pdu(), 3);
    assert_eq!(response.version(), 1);
    assert_eq!(response.session_id(), 42);
    // First PDU: the higher IPv4 address (descending order).
    let first = IPv4Prefix::read(&mut client).await.unwrap();
    dump.push_value(first.pdu(), dump_ipv4_prefix(&first));
    assert_eq!(first.pdu(), 4);
    assert_eq!(first.version(), 1);
    assert!(first.flag().is_announce());
    assert_eq!(first.prefix(), Ipv4Addr::new(198, 51, 100, 0))
;
    assert_eq!(first.prefix_len(), 24);
    assert_eq!(first.max_len(), 24);
    assert_eq!(first.asn(), 64497u32.into());
    let second = IPv4Prefix::read(&mut client).await.unwrap();
    dump.push_value(second.pdu(), dump_ipv4_prefix(&second));
    assert_eq!(second.pdu(), 4);
    assert_eq!(second.version(), 1);
    assert!(second.flag().is_announce());
    assert_eq!(second.prefix(), Ipv4Addr::new(192, 0, 2, 0));
    assert_eq!(second.prefix_len(), 24);
    assert_eq!(second.max_len(), 24);
    assert_eq!(second.asn(), 64496u32.into());
    assert!(u32::from(first.prefix()) > u32::from(second.prefix()));
    // IPv6 prefix only after all IPv4 prefixes.
    let third = IPv6Prefix::read(&mut client).await.unwrap();
    dump.push_value(third.pdu(), dump_ipv6_prefix(&third));
    assert_eq!(third.pdu(), 6);
    assert_eq!(third.version(), 1);
    assert!(third.flag().is_announce());
    assert_eq!(third.prefix(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
    assert_eq!(third.prefix_len(), 32);
    assert_eq!(third.max_len(), 48);
    assert_eq!(third.asn(), 64498u32.into());
    let eod = EndOfDataV1::read(&mut client).await.unwrap();
    dump.push_value(eod.pdu(), dump_eod_v1(&eod));
    assert_eq!(eod.pdu(), 7);
    assert_eq!(eod.version(), 1);
    assert_eq!(eod.session_id(), 42);
    assert_eq!(eod.serial_number(), 100);
    let timing = eod.timing();
    assert_eq!(timing.refresh, 600);
    assert_eq!(timing.retry, 600);
    assert_eq!(timing.expire, 7200);
    dump.print_pretty("reset_query_returns_payloads_in_rtr_order");
}
#[tokio::test]
async fn serial_query_returns_announcements_before_withdrawals() {
let announced_low_prefix = IPAddressPrefix {
address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)),
prefix_length: 24,
};
let announced_low_origin = RouteOrigin::new(announced_low_prefix, 24, 64496u32.into());
let announced_high_prefix = IPAddressPrefix {
address: IPAddress::from_ipv4(Ipv4Addr::new(198, 51, 100, 0)),
prefix_length: 24,
};
let announced_high_origin = RouteOrigin::new(announced_high_prefix, 24, 64497u32.into());
let withdrawn_low_prefix = IPAddressPrefix {
address: IPAddress::from_ipv4(Ipv4Addr::new(10, 0, 0, 0)),
prefix_length: 24,
};
let withdrawn_low_origin = RouteOrigin::new(withdrawn_low_prefix, 24, 64500u32.into());
let withdrawn_high_prefix = IPAddressPrefix {
address: IPAddress::from_ipv4(Ipv4Addr::new(203, 0, 113, 0)),
prefix_length: 24,
};
let withdrawn_high_origin = RouteOrigin::new(withdrawn_high_prefix, 24, 64501u32.into());
let delta = Arc::new(Delta::new(
101,
vec![
Payload::RouteOrigin(announced_low_origin),
Payload::RouteOrigin(announced_high_origin),
],
vec![
Payload::RouteOrigin(withdrawn_high_origin),
Payload::RouteOrigin(withdrawn_low_origin),
],
));
let mut deltas = VecDeque::new();
deltas.push_back(delta);
let cache = RtrCacheBuilder::new()
.session_id(42)
.serial(101)
.timing(Timing {
refresh: 600,
retry: 600,
expire: 7200,
})
.deltas(deltas)
.build();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_cache = shared_cache(cache);
let (_notify_tx, notify_rx) = broadcast::channel(16);
let (_shutdown_tx, shutdown_rx) = watch::channel(false);
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx);
session.run().await.unwrap();
});
let mut client = tokio::net::TcpStream::connect(addr).await.unwrap();
SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap();
let mut dump = RtrDebugDumper::new();
let response = CacheResponse::read(&mut client).await.unwrap();
dump.push_value(response.pdu(), dump_cache_response(&response));
assert_eq!(response.pdu(), 3);
assert_eq!(response.version(), 1);
assert_eq!(response.session_id(), 42);
let first = IPv4Prefix::read(&mut client).await.unwrap();
dump.push_value(first.pdu(), dump_ipv4_prefix(&first));
assert_eq!(first.pdu(), 4);
assert_eq!(first.version(), 1);
assert!(first.flag().is_announce());
assert_eq!(first.prefix(), Ipv4Addr::new(198, 51, 100, 0));
assert_eq!(first.prefix_len(), 24);
assert_eq!(first.max_len(), 24);
assert_eq!(first.asn(), 64497u32.into());
let second = IPv4Prefix::read(&mut client).await.unwrap();
dump.push_value(second.pdu(), dump_ipv4_prefix(&second));
assert_eq!(second.pdu(), 4);
assert_eq!(second.version(), 1);
assert!(second.flag().is_announce());
assert_eq!(second.prefix(), Ipv4Addr::new(192, 0, 2, 0));
assert_eq!(second.prefix_len(), 24);
assert_eq!(second.max_len(), 24);
assert_eq!(second.asn(), 64496u32.into());
assert!(u32::from(first.prefix()) > u32::from(second.prefix()));
let third = IPv4Prefix::read(&mut client).await.unwrap();
dump.push_value(third.pdu(), dump_ipv4_prefix(&third));
assert_eq!(third.pdu(), 4);
assert_eq!(third.version(), 1);
assert!(!third.flag().is_announce());
assert_eq!(third.prefix(), Ipv4Addr::new(10, 0, 0, 0));
assert_eq!(third.prefix_len(), 24);
assert_eq!(third.max_len(), 24);
assert_eq!(third.asn(), 64500u32.into());
let fourth = IPv4Prefix::read(&mut client).await.unwrap();
dump.push_value(fourth.pdu(), dump_ipv4_prefix(&fourth));
assert_eq!(fourth.pdu(), 4);
assert_eq!(fourth.version(), 1);
assert!(!fourth.flag().is_announce());
assert_eq!(fourth.prefix(), Ipv4Addr::new(203, 0, 113, 0));
assert_eq!(fourth.prefix_len(), 24);
assert_eq!(fourth.max_len(), 24);
assert_eq!(fourth.asn(), 64501u32.into());
assert!(u32::from(third.prefix()) < u32::from(fourth.prefix()));
assert!(first.flag().is_announce());
assert!(second.flag().is_announce());
assert!(!third.flag().is_announce());
assert!(!fourth.flag().is_announce());
let eod = EndOfDataV1::read(&mut client).await.unwrap();
dump.push_value(eod.pdu(), dump_eod_v1(&eod));
assert_eq!(eod.pdu(), 7);
assert_eq!(eod.version(), 1);
assert_eq!(eod.session_id(), 42);
assert_eq!(eod.serial_number(), 101);
let timing = eod.timing();
assert_eq!(timing.refresh, 600);
assert_eq!(timing.retry, 600);
assert_eq!(timing.expire, 7200);
dump.print_pretty("serial_query_returns_announcements_before_withdrawals");
}

352
tests/test_store_db.rs Normal file
View File

@ -0,0 +1,352 @@
mod common;
use std::net::Ipv6Addr;
use common::test_helper::{
indent_block, payloads_to_string, test_report, v4_origin, v6_origin,
};
use rpki::rtr::cache::{Delta, Snapshot};
use rpki::rtr::payload::Payload;
use rpki::rtr::store_db::RtrStore;
/// Renders the snapshot's RTR payload set as a human-readable report string.
fn snapshot_to_string(snapshot: &Snapshot) -> String {
    payloads_to_string(&snapshot.payloads_for_rtr())
}
/// Renders a `Delta` — its serial plus the announced and withdrawn payload
/// lists — as an indented report string.
fn delta_to_string(delta: &Delta) -> String {
    let announced = indent_block(&payloads_to_string(delta.announced()), 2);
    let withdrawn = indent_block(&payloads_to_string(delta.withdrawn()), 2);
    format!(
        "serial: {}\nannounced:\n{}withdrawn:\n{}",
        delta.serial(),
        announced,
        withdrawn,
    )
}
/// Round-trip: a snapshot written with `save_snapshot()` must be readable
/// back via `get_snapshot()` with identical content.
#[test]
fn store_db_save_and_get_snapshot() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    // One v4 and one v6 origin so both payload encodings hit the store.
    let payloads = vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
        Payload::RouteOrigin(v6_origin(
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0),
            32,
            48,
            64497,
        )),
    ];
    let snapshot = Snapshot::from_payloads(payloads.clone());
    store.save_snapshot(&snapshot).unwrap();
    let loaded = store.get_snapshot().unwrap().expect("snapshot should exist");
    let report_input = format!(
        "db_path: {}\nsnapshot:\n{}",
        tmp.path().display(),
        indent_block(&payloads_to_string(&payloads), 2),
    );
    let report_output = format!(
        "loaded snapshot:\n{}same_content: {}\n",
        indent_block(&snapshot_to_string(&loaded), 2),
        snapshot.same_content(&loaded),
    );
    test_report(
        "store_db_save_and_get_snapshot",
        "验证 save_snapshot() 后可以通过 get_snapshot() 正确读回 Snapshot。",
        &report_input,
        &report_output,
    );
    assert!(snapshot.same_content(&loaded));
}
/// Meta round-trip: session_id, serial and the delta window all survive a
/// write followed by a read.
#[test]
fn store_db_set_and_get_meta_fields() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    store.set_session_id(42).unwrap();
    store.set_serial(100).unwrap();
    store.set_delta_window(101, 110).unwrap();
    let got_session = store.get_session_id().unwrap();
    let got_serial = store.get_serial().unwrap();
    let got_window = store.get_delta_window().unwrap();
    test_report(
        "store_db_set_and_get_meta_fields",
        "验证 session_id / serial / delta_window 能正确写入并读回。",
        &format!(
            "db_path: {}\nset_session_id=42\nset_serial=100\nset_delta_window=(101, 110)\n",
            tmp.path().display(),
        ),
        &format!(
            "get_session_id: {:?}\nget_serial: {:?}\nget_delta_window: {:?}\n",
            got_session, got_serial, got_window,
        ),
    );
    assert_eq!(got_session, Some(42));
    assert_eq!(got_serial, Some(100));
    assert_eq!(got_window, Some((101, 110)));
}
/// Round-trip: a delta written with `save_delta()` must be readable back via
/// `get_delta(serial)` with the same serial and payload counts.
#[test]
fn store_db_save_and_get_delta() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    let announced = vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))];
    let withdrawn = vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))];
    let delta = Delta::new(101, announced, withdrawn);
    store.save_delta(&delta).unwrap();
    let loaded = store.get_delta(101).unwrap().expect("delta should exist");
    let report_input = format!(
        "db_path: {}\ndelta:\n{}",
        tmp.path().display(),
        indent_block(&delta_to_string(&delta), 2),
    );
    let report_output = format!(
        "loaded delta:\n{}",
        indent_block(&delta_to_string(&loaded), 2),
    );
    test_report(
        "store_db_save_and_get_delta",
        "验证 save_delta() 后可以通过 get_delta(serial) 正确读回 Delta。",
        &report_input,
        &report_output,
    );
    assert_eq!(loaded.serial(), 101);
    assert_eq!(loaded.announced().len(), 1);
    assert_eq!(loaded.withdrawn().len(), 1);
}
/// `load_deltas_since(x)` must return only deltas with serial > x, ordered by
/// ascending serial: after saving 101..=103, querying since 101 yields
/// exactly [102, 103].
#[test]
fn store_db_load_deltas_since_returns_only_newer_deltas_in_order() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    // Three consecutive announce-only deltas.
    let saved = [
        Delta::new(
            101,
            vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))],
            vec![],
        ),
        Delta::new(
            102,
            vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
            vec![],
        ),
        Delta::new(
            103,
            vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))],
            vec![],
        ),
    ];
    for delta in &saved {
        store.save_delta(delta).unwrap();
    }
    let loaded = store.load_deltas_since(101).unwrap();
    let report_input = format!(
        "db_path: {}\nsaved delta serials: [101, 102, 103]\nload_deltas_since(101)\n",
        tmp.path().display(),
    );
    let report_output: String = loaded
        .iter()
        .enumerate()
        .map(|(idx, d)| {
            format!("loaded[{}]:\n", idx) + &indent_block(&delta_to_string(d), 2)
        })
        .collect();
    test_report(
        "store_db_load_deltas_since_returns_only_newer_deltas_in_order",
        "验证 load_deltas_since(x) 只返回 serial > x 的 Delta且顺序正确。",
        &report_input,
        &report_output,
    );
    assert_eq!(loaded.len(), 2);
    assert_eq!(loaded[0].serial(), 102);
    assert_eq!(loaded[1].serial(), 103);
}
/// `save_snapshot_and_meta()` is a compound write: the snapshot, the
/// session_id and the serial must all be readable back individually.
#[test]
fn store_db_save_snapshot_and_meta_writes_all_fields() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
        Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)),
    ]);
    store.save_snapshot_and_meta(&snapshot, 42, 100).unwrap();
    let got_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist");
    let got_session = store.get_session_id().unwrap();
    let got_serial = store.get_serial().unwrap();
    let report_input = format!(
        "db_path: {}\nsnapshot:\n{}session_id=42\nserial=100\n",
        tmp.path().display(),
        indent_block(&snapshot_to_string(&snapshot), 2),
    );
    let report_output = format!(
        "loaded_snapshot:\n{}loaded_session_id: {:?}\nloaded_serial: {:?}\n",
        indent_block(&snapshot_to_string(&got_snapshot), 2),
        got_session,
        got_serial,
    );
    test_report(
        "store_db_save_snapshot_and_meta_writes_all_fields",
        "验证 save_snapshot_and_meta() 会同时写入 snapshot、session_id 和 serial。",
        &report_input,
        &report_output,
    );
    assert!(snapshot.same_content(&got_snapshot));
    assert_eq!(got_session, Some(42));
    assert_eq!(got_serial, Some(100));
}
/// `load_snapshot_and_serial()` must hand back the same snapshot/serial pair
/// that `save_snapshot_and_serial()` stored.
#[test]
fn store_db_load_snapshot_and_serial_returns_consistent_pair() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498)),
    ]);
    store.save_snapshot_and_serial(&snapshot, 200).unwrap();
    let (got_snapshot, got_serial) = store
        .load_snapshot_and_serial()
        .unwrap()
        .expect("snapshot+serial should exist");
    let report_input = format!(
        "db_path: {}\nsnapshot:\n{}serial=200\n",
        tmp.path().display(),
        indent_block(&snapshot_to_string(&snapshot), 2),
    );
    let report_output = format!(
        "loaded_snapshot:\n{}loaded_serial: {}\n",
        indent_block(&snapshot_to_string(&got_snapshot), 2),
        got_serial,
    );
    test_report(
        "store_db_load_snapshot_and_serial_returns_consistent_pair",
        "验证 load_snapshot_and_serial() 能正确返回一致的 snapshot 与 serial。",
        &report_input,
        &report_output,
    );
    assert!(snapshot.same_content(&got_snapshot));
    assert_eq!(got_serial, 200);
}
/// After `delete_snapshot()`, `delete_delta()` and `delete_serial()`, the
/// corresponding getters must all return `None`.
#[test]
fn store_db_delete_snapshot_delta_and_serial_removes_data() {
    let tmp = tempfile::tempdir().unwrap();
    let store = RtrStore::open(tmp.path()).unwrap();
    // Seed the store with one snapshot, one delta and a serial …
    let snapshot = Snapshot::from_payloads(vec![
        Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
    ]);
    let delta = Delta::new(
        101,
        vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))],
        vec![],
    );
    store.save_snapshot(&snapshot).unwrap();
    store.save_delta(&delta).unwrap();
    store.set_serial(100).unwrap();
    // … then delete all three.
    store.delete_snapshot().unwrap();
    store.delete_delta(101).unwrap();
    store.delete_serial().unwrap();
    let got_snapshot = store.get_snapshot().unwrap();
    let got_delta = store.get_delta(101).unwrap();
    let got_serial = store.get_serial().unwrap();
    let report_input = format!(
        "db_path: {}\nsave snapshot + delta(101) + serial(100), then delete all three.\n",
        tmp.path().display(),
    );
    let report_output = format!(
        "get_snapshot: {:?}\nget_delta(101): {:?}\nget_serial: {:?}\n",
        got_snapshot.as_ref().map(|_| "Some(snapshot)"),
        got_delta.as_ref().map(|_| "Some(delta)"),
        got_serial,
    );
    test_report(
        "store_db_delete_snapshot_delta_and_serial_removes_data",
        "验证 delete_snapshot()/delete_delta()/delete_serial() 后,对应数据不再可读。",
        &report_input,
        &report_output,
    );
    assert!(got_snapshot.is_none());
    assert!(got_delta.is_none());
    assert!(got_serial.is_none());
}
#[test]
fn store_db_load_snapshot_and_serial_errors_on_inconsistent_state() {
let dir = tempfile::tempdir().unwrap();
let store = RtrStore::open(dir.path()).unwrap();
let snapshot = Snapshot::from_payloads(vec![
Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)),
]);
store.save_snapshot(&snapshot).unwrap();
// 故意不写 serial制造不一致状态
let result = store.load_snapshot_and_serial();
let input = format!(
"db_path: {}\n仅保存 snapshot不保存 serial。\n",
dir.path().display(),
);
let output = format!("load_snapshot_and_serial result: {:?}\n", result);
test_report(
"store_db_load_snapshot_and_serial_errors_on_inconsistent_state",
"验证当 snapshot 和 serial 状态不一致时load_snapshot_and_serial() 返回错误。",
&input,
&output,
);
assert!(result.is_err());
}