diff --git a/.gitignore b/.gitignore index 2c96eb1..6e61639 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target/ Cargo.lock +rtr-db/ \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index fa7ca1c..16e3f63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,15 @@ chrono = "0.4.44" bytes = "1.11.1" tokio = { version = "1.49.0", features = ["full"] } rand = "0.10.0" -rocksdb = "0.21" -serde = { version = "1", features = ["derive"] } +rocksdb = { version = "0.21.0", default-features = false } +serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" anyhow = "1" -bincode = "3.0.0" -tracing = "0.1.44" \ No newline at end of file +tracing = "0.1.44" +sha2 = "0.10" +tempfile = "3" +tokio-rustls = "0.26" +rustls = "0.23" +rustls-pemfile = "2" +rustls-pki-types = "1.14.0" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } \ No newline at end of file diff --git a/src/bin/rtr_debug_client/README.md b/src/bin/rtr_debug_client/README.md new file mode 100644 index 0000000..9d95b97 --- /dev/null +++ b/src/bin/rtr_debug_client/README.md @@ -0,0 +1,63 @@ +# rtr_debug_client + +`rtr_debug_client` 是一个用于调试和联调 RTR(RPKI-to-Router)服务端的小型客户端工具。 + +它的目标不是做一个完整的生产级 router client,而是提供一个简单、直接、可观察的调试入口,用于: + +- 连接 RTR server +- 发送 `Reset Query` 或 `Serial Query` +- 接收并打印服务端返回的 PDU +- 辅助排查协议实现、会话状态、序列号增量、PDU 编码等问题 + +--- + +## 适用场景 + +这个工具适合以下场景: + +- 开发 RTR server 时做本地联调 +- 验证服务端是否正确返回 `Cache Response` +- 检查 `IPv4 Prefix` / `IPv6 Prefix` / `ASPA` / `End of Data` 等 PDU +- 验证 `Serial Query` 路径是否正确 +- 观察异常响应,例如 `Cache Reset` 或 `Error Report` +- 后续扩展为支持 TLS、自动断言、会话统计等调试能力 + +--- + +## 当前能力 + +当前版本支持: + +- TCP 连接 RTR server +- 发送 `Reset Query` +- 发送 `Serial Query` +- 持续读取服务端返回的 PDU +- 解析并打印以下常见 PDU: + - `Serial Notify` + - `Serial Query` + - `Reset Query` + - `Cache Response` + - `IPv4 Prefix` + - `IPv6 Prefix` + - `End of Data` + - `Cache Reset` + - `Error Report` + - `ASPA` +- 基础长度校验 +- 最大 PDU 长度限制,防止异常数据导致过大内存分配 + +--- + +## 目录结构 + +建议目录如下: + +```text +src/ +└── bin/ + └── rtr_debug_client/ + ├── main.rs + ├── protocol.rs + ├── io.rs + ├── pretty.rs + └── README.md \ No newline at end of file diff --git a/src/bin/rtr_debug_client/main.rs b/src/bin/rtr_debug_client/main.rs new file mode 100644 index 0000000..4f2cee9 --- /dev/null +++ b/src/bin/rtr_debug_client/main.rs @@ -0,0 +1,582 @@ +use std::env; +use std::io; + +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::TcpStream; +use tokio::time::{timeout, Duration, Instant}; + +mod wire; +mod pretty; +mod protocol; + +use crate::wire::{read_pdu, send_reset_query, send_serial_query}; +use crate::pretty::{ + parse_end_of_data_info, parse_serial_notify_serial, print_pdu, +}; +use crate::protocol::{PduHeader, PduType, QueryMode}; + +const DEFAULT_READ_TIMEOUT_SECS: u64 = 30; +const DEFAULT_POLL_INTERVAL_SECS: u64 = 60; + +#[tokio::main] +async fn main() -> io::Result<()> { + let config = Config::from_args()?; + + println!("== RTR debug client =="); + println!("target : {}", config.addr); + println!("version : {}", config.version); + println!("timeout : {}s", config.read_timeout_secs); + println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs); + match &config.mode { + QueryMode::Reset => { + println!("mode : reset"); + } + QueryMode::Serial { session_id, serial } => { + println!("mode : serial"); + println!("session : {}", session_id); + println!("serial : {}", serial); + } + } + println!(); + print_help(); + + let stream = TcpStream::connect(&config.addr).await?; + println!("connected to {}", config.addr); + + let (mut reader, mut writer) = stream.into_split(); + let mut state = ClientState::new( + config.version, + config.read_timeout_secs, + config.default_poll_secs, + ); + + match config.mode { + QueryMode::Reset => { + send_reset_query(&mut writer, config.version).await?; + println!("sent Reset Query"); + } + QueryMode::Serial { session_id, serial } => { + state.session_id = Some(session_id); + state.serial = Some(serial); + send_serial_query(&mut writer, config.version, session_id, serial).await?; + println!("sent Serial Query"); + } + } + + state.schedule_next_poll(); + println!(); + + let stdin = tokio::io::stdin(); + let mut stdin_lines = BufReader::new(stdin).lines(); + + loop { + let poll_sleep = tokio::time::sleep_until(state.next_poll_deadline); + tokio::pin!(poll_sleep); + + tokio::select! { + line = stdin_lines.next_line() => { + match line { + Ok(Some(line)) => { + let should_quit = handle_console_command( + &line, + &mut writer, + &mut state, + ).await?; + + if should_quit { + println!("quit requested, closing client."); + break; + } + } + Ok(None) => { + println!("stdin closed, continue network loop."); + } + Err(err) => { + eprintln!("read stdin failed: {}", err); + } + } + } + + _ = &mut poll_sleep => { + handle_poll_tick(&mut writer, &mut state).await?; + state.schedule_next_poll(); + } + + read_result = timeout( + Duration::from_secs(state.read_timeout_secs), + read_pdu(&mut reader) + ) => { + match read_result { + Ok(Ok(pdu)) => { + print_pdu(&pdu.header, &pdu.body); + handle_incoming_pdu(&mut writer, &mut state, &pdu.header, &pdu.body).await?; + } + Ok(Err(err)) => { + eprintln!("read PDU failed: {}", err); + return Err(err); + } + Err(_) => { + println!( + "[timeout] no PDU received in {}s, connection kept open.", + state.read_timeout_secs + ); + } + } + } + } + } + + Ok(()) +} + +async fn handle_incoming_pdu( + writer: &mut OwnedWriteHalf, + state: &mut ClientState, + header: &PduHeader, + body: &[u8], +) -> io::Result<()> { + match header.pdu_type() { + PduType::CacheResponse => { + state.current_session_id = Some(header.session_id()); + } + + PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::Aspa => { + if state.current_session_id.is_none() { + state.current_session_id = Some(header.session_id()); + } + } + + PduType::EndOfData => { + let session_id = header.session_id(); + let eod = parse_end_of_data_info(body); + + state.session_id = Some(session_id); + state.current_session_id = Some(session_id); + + println!(); + + if let Some(eod) = eod { + state.serial = Some(eod.serial); + state.refresh = eod.refresh; + state.retry = eod.retry; + state.expire = eod.expire; + + println!( + "updated client state: session_id={}, serial={}", + session_id, eod.serial + ); + + if let Some(refresh) = eod.refresh { + println!("refresh : {}", refresh); + } + if let Some(retry) = eod.retry { + println!("retry : {}", retry); + } + if let Some(expire) = eod.expire { + println!("expire : {}", expire); + } + + state.schedule_next_poll(); + println!( + "next auto poll scheduled after {}s", + state.effective_poll_secs() + ); + } else { + println!( + "updated client state: session_id={}, serial=", + session_id + ); + } + + println!("received EndOfData, keep connection open."); + println!(); + } + + PduType::SerialNotify => { + let notify_session_id = header.session_id(); + let notify_serial = parse_serial_notify_serial(body); + + println!(); + + match (state.session_id, state.serial, notify_serial) { + (Some(current_session_id), Some(current_serial), Some(_new_serial)) + if current_session_id == notify_session_id => + { + println!( + "received Serial Notify for current session {}, send Serial Query with serial {}", + current_session_id, current_serial + ); + send_serial_query( + writer, + state.version, + current_session_id, + current_serial, + ) + .await?; + } + + _ => { + println!( + "received Serial Notify but local session/serial state is not usable, send Reset Query" + ); + send_reset_query(writer, state.version).await?; + } + } + + state.schedule_next_poll(); + println!(); + } + + PduType::CacheReset => { + println!(); + println!("received Cache Reset, send Reset Query"); + state.current_session_id = None; + state.serial = None; + send_reset_query(writer, state.version).await?; + state.schedule_next_poll(); + println!(); + } + + PduType::ErrorReport => { + println!(); + println!("received Error Report, keep connection open for debugging."); + if let Some(retry) = state.retry { + println!("will keep auto polling; server retry hint currently stored: {}s", retry); + } + println!(); + } + + PduType::SerialQuery | PduType::ResetQuery | PduType::Unknown(_) => { + // only print, no extra action + } + } + + Ok(()) +} + +async fn handle_poll_tick( + writer: &mut OwnedWriteHalf, + state: &mut ClientState, +) -> io::Result<()> { + println!(); + println!( + "[auto-poll] timer fired (interval={}s)", + state.effective_poll_secs() + ); + + match (state.session_id, state.serial) { + (Some(session_id), Some(serial)) => { + println!( + "[auto-poll] send Serial Query with session_id={}, serial={}", + session_id, serial + ); + send_serial_query(writer, state.version, session_id, serial).await?; + } + _ => { + println!("[auto-poll] local state incomplete, send Reset Query"); + send_reset_query(writer, state.version).await?; + } + } + + println!(); + Ok(()) +} + +async fn handle_console_command( + line: &str, + writer: &mut OwnedWriteHalf, + state: &mut ClientState, +) -> io::Result { + let line = line.trim(); + + if line.is_empty() { + return Ok(false); + } + + let parts: Vec<&str> = line.split_whitespace().collect(); + + match parts.as_slice() { + ["help"] => { + print_help(); + } + + ["state"] => { + print_state(state); + } + + ["reset"] => { + println!("manual command: send Reset Query"); + send_reset_query(writer, state.version).await?; + state.schedule_next_poll(); + } + + ["serial"] => { + match (state.session_id, state.serial) { + (Some(session_id), Some(serial)) => { + println!( + "manual command: send Serial Query with current state: session_id={}, serial={}", + session_id, serial + ); + send_serial_query(writer, state.version, session_id, serial).await?; + state.schedule_next_poll(); + } + _ => { + println!( + "manual command failed: current session_id/serial not available, use `reset` or `serial `" + ); + } + } + } + + ["serial", session_id, serial] => { + let session_id = match session_id.parse::() { + Ok(v) => v, + Err(err) => { + println!("invalid session_id: {}", err); + return Ok(false); + } + }; + + let serial = match serial.parse::() { + Ok(v) => v, + Err(err) => { + println!("invalid serial: {}", err); + return Ok(false); + } + }; + + println!( + "manual command: send Serial Query with explicit args: session_id={}, serial={}", + session_id, serial + ); + state.session_id = Some(session_id); + state.serial = Some(serial); + send_serial_query(writer, state.version, session_id, serial).await?; + state.schedule_next_poll(); + } + + ["timeout"] => { + println!("current read timeout: {}s", state.read_timeout_secs); + } + + ["timeout", secs] => { + let secs = match secs.parse::() { + Ok(v) if v > 0 => v, + Ok(_) => { + println!("timeout must be > 0"); + return Ok(false); + } + Err(err) => { + println!("invalid timeout seconds: {}", err); + return Ok(false); + } + }; + + state.read_timeout_secs = secs; + println!("updated read timeout to {}s", state.read_timeout_secs); + } + + ["poll"] => { + println!( + "current effective poll interval: {}s", + state.effective_poll_secs() + ); + println!("stored refresh hint : {:?}", state.refresh); + println!("default poll interval : {}s", state.default_poll_secs); + } + + ["poll", secs] => { + let secs = match secs.parse::() { + Ok(v) if v > 0 => v, + Ok(_) => { + println!("poll interval must be > 0"); + return Ok(false); + } + Err(err) => { + println!("invalid poll seconds: {}", err); + return Ok(false); + } + }; + + state.refresh = Some(secs as u32); + state.schedule_next_poll(); + println!("updated poll interval to {}s", secs); + } + + ["quit"] | ["exit"] => { + return Ok(true); + } + + _ => { + println!("unknown command: {}", line); + print_help(); + } + } + + Ok(false) +} + +fn print_help() { + println!("available commands:"); + println!(" help show this help"); + println!(" state print current client state"); + println!(" reset send Reset Query"); + println!(" serial send Serial Query with current session_id/serial"); + println!(" serial send Serial Query with explicit values"); + println!(" timeout show current read timeout"); + println!(" timeout update read timeout seconds"); + println!(" poll show current poll interval"); + println!(" poll override poll interval seconds"); + println!(" quit exit client"); + println!(); +} + +fn print_state(state: &ClientState) { + println!("client state:"); + println!(" version : {}", state.version); + println!(" session_id : {:?}", state.session_id); + println!(" serial : {:?}", state.serial); + println!(" current_session_id : {:?}", state.current_session_id); + println!(" refresh : {:?}", state.refresh); + println!(" retry : {:?}", state.retry); + println!(" expire : {:?}", state.expire); + println!(" read_timeout_secs : {}", state.read_timeout_secs); + println!(" default_poll_secs : {}", state.default_poll_secs); + println!(" effective_poll_secs: {}", state.effective_poll_secs()); + println!(); +} + +#[derive(Debug)] +struct ClientState { + version: u8, + session_id: Option, + serial: Option, + current_session_id: Option, + + refresh: Option, + retry: Option, + expire: Option, + + read_timeout_secs: u64, + default_poll_secs: u64, + next_poll_deadline: Instant, +} + +impl ClientState { + fn new(version: u8, read_timeout_secs: u64, default_poll_secs: u64) -> Self { + Self { + version, + session_id: None, + serial: None, + current_session_id: None, + refresh: None, + retry: None, + expire: None, + read_timeout_secs, + default_poll_secs, + next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs), + } + } + + fn effective_poll_secs(&self) -> u64 { + self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs) + } + + fn schedule_next_poll(&mut self) { + self.next_poll_deadline = + Instant::now() + Duration::from_secs(self.effective_poll_secs()); + } +} + +#[derive(Debug)] +struct Config { + addr: String, + version: u8, + mode: QueryMode, + read_timeout_secs: u64, + default_poll_secs: u64, +} + +impl Config { + fn from_args() -> io::Result { + let mut args = env::args().skip(1); + + let addr = args + .next() + .unwrap_or_else(|| "127.0.0.1:3323".to_string()); + + let version = args + .next() + .map(|s| { + s.parse::().map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid version '{}': {}", s, e), + ) + }) + }) + .transpose()? + .unwrap_or(1); + + if version > 2 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unsupported RTR version {}, expected 0..=2", version), + )); + } + + let mode = match args.next().as_deref() { + None | Some("reset") => QueryMode::Reset, + Some("serial") => { + let session_id = args + .next() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "serial mode requires session_id and serial", + ) + })? + .parse::() + .map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid session_id: {}", e), + ) + })?; + + let serial = args + .next() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "serial mode requires serial", + ) + })? + .parse::() + .map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid serial: {}", e), + ) + })?; + + QueryMode::Serial { session_id, serial } + } + Some(other) => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("invalid mode '{}', expected 'reset' or 'serial'", other), + )); + } + }; + + Ok(Self { + addr, + version, + mode, + read_timeout_secs: DEFAULT_READ_TIMEOUT_SECS, + default_poll_secs: DEFAULT_POLL_INTERVAL_SECS, + }) + } +} \ No newline at end of file diff --git a/src/bin/rtr_debug_client/pretty.rs b/src/bin/rtr_debug_client/pretty.rs new file mode 100644 index 0000000..3de9202 --- /dev/null +++ b/src/bin/rtr_debug_client/pretty.rs @@ -0,0 +1,303 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; + +use crate::protocol::{ + flag_meaning, hex_bytes, PduHeader, PduType, ASPA_FIXED_BODY_LEN, + END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN, + IPV6_PREFIX_BODY_LEN, ROUTER_KEY_FIXED_BODY_LEN, +}; + +pub fn print_pdu(header: &PduHeader, body: &[u8]) { + println!("--------------------------------------------------"); + println!("PDU: {}", header.pdu_type()); + println!("version : {}", header.version); + println!("length : {}", header.length); + + match header.pdu_type() { + PduType::CacheResponse => { + println!("session_id : {}", header.session_id()); + } + PduType::CacheReset => { + println!("cache reset"); + } + PduType::Ipv4Prefix => { + print_ipv4_prefix(header, body); + } + PduType::Ipv6Prefix => { + print_ipv6_prefix(header, body); + } + PduType::EndOfData => { + print_end_of_data(header, body); + } + PduType::ErrorReport => { + print_error_report(header, body); + } + PduType::SerialNotify => { + print_serial_notify(header, body); + } + PduType::SerialQuery => { + print_serial_query(header, body); + } + PduType::Aspa => { + print_aspa(header, body); + } + PduType::ResetQuery => { + println!("reset query"); + } + PduType::Unknown(_) => { + println!("field1 : {}", header.field1); + println!("body : {}", hex_bytes(body)); + } + } +} + +fn print_ipv4_prefix(header: &PduHeader, body: &[u8]) { + if body.len() != IPV4_PREFIX_BODY_LEN { + println!("invalid IPv4 Prefix body length: {}", body.len()); + println!("raw body: {}", hex_bytes(body)); + return; + } + + let flags = body[0]; + let prefix_len = body[1]; + let max_len = body[2]; + let zero = body[3]; + let prefix = Ipv4Addr::new(body[4], body[5], body[6], body[7]); + let asn = u32::from_be_bytes([body[8], body[9], body[10], body[11]]); + + println!("session_id : {}", header.session_id()); + println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags)); + println!("prefix_len : {}", prefix_len); + println!("max_len : {}", max_len); + println!("zero : {}", zero); + println!("prefix : {}", prefix); + println!("asn : {}", asn); +} + +fn print_ipv6_prefix(header: &PduHeader, body: &[u8]) { + if body.len() != IPV6_PREFIX_BODY_LEN { + println!("invalid IPv6 Prefix body length: {}", body.len()); + println!("raw body: {}", hex_bytes(body)); + return; + } + + let flags = body[0]; + let prefix_len = body[1]; + let max_len = body[2]; + let zero = body[3]; + + let mut addr = [0u8; 16]; + addr.copy_from_slice(&body[4..20]); + let prefix = Ipv6Addr::from(addr); + + let asn = u32::from_be_bytes([body[20], body[21], body[22], body[23]]); + + println!("session_id : {}", header.session_id()); + println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags)); + println!("prefix_len : {}", prefix_len); + println!("max_len : {}", max_len); + println!("zero : {}", zero); + println!("prefix : {}", prefix); + println!("asn : {}", asn); +} + +fn print_end_of_data(header: &PduHeader, body: &[u8]) { + println!("session_id : {}", header.session_id()); + + match body.len() { + END_OF_DATA_V0_BODY_LEN => { + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + println!("serial : {}", serial); + println!("variant : v0"); + } + END_OF_DATA_V1_BODY_LEN => { + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + let refresh = u32::from_be_bytes([body[4], body[5], body[6], body[7]]); + let retry = u32::from_be_bytes([body[8], body[9], body[10], body[11]]); + let expire = u32::from_be_bytes([body[12], body[13], body[14], body[15]]); + println!("serial : {}", serial); + println!("refresh : {}", refresh); + println!("retry : {}", retry); + println!("expire : {}", expire); + println!("variant : v1/v2"); + } + _ => { + println!("invalid EndOfData body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + } + } +} + +fn print_error_report(header: &PduHeader, body: &[u8]) { + println!("error_code : {}", header.error_code()); + + if body.len() < 8 { + println!("invalid ErrorReport body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let encapsulated_len = + u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize; + + if body.len() < 4 + encapsulated_len + 4 { + println!("invalid ErrorReport: truncated encapsulated PDU"); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let encapsulated = &body[4..4 + encapsulated_len]; + + let text_len_offset = 4 + encapsulated_len; + let text_len = u32::from_be_bytes([ + body[text_len_offset], + body[text_len_offset + 1], + body[text_len_offset + 2], + body[text_len_offset + 3], + ]) as usize; + + if body.len() < text_len_offset + 4 + text_len { + println!("invalid ErrorReport: truncated text"); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let text_bytes = &body[text_len_offset + 4..text_len_offset + 4 + text_len]; + let text = String::from_utf8_lossy(text_bytes); + + println!("encap_len : {}", encapsulated_len); + println!("encap_pdu : {}", hex_bytes(encapsulated)); + println!("text_len : {}", text_len); + println!("text : {}", text); +} + +fn print_serial_notify(header: &PduHeader, body: &[u8]) { + if body.len() != 4 { + println!("invalid Serial Notify body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + println!("session_id : {}", header.session_id()); + println!("serial : {}", serial); +} + +fn print_serial_query(header: &PduHeader, body: &[u8]) { + if body.len() != 4 { + println!("invalid Serial Query body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + println!("session_id : {}", header.session_id()); + println!("serial : {}", serial); +} + +#[allow(dead_code)] +fn print_router_key(header: &PduHeader, body: &[u8]) { + println!("session_id : {}", header.session_id()); + + if body.len() < ROUTER_KEY_FIXED_BODY_LEN { + println!("invalid Router Key body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let flags = body[0]; + let zero = body[1]; + let ski = &body[2..22]; + let asn = u32::from_be_bytes([body[22], body[23], body[24], body[25]]); + let spki = &body[26..]; + + println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags)); + println!("zero : {}", zero); + println!("ski : {}", hex_bytes(ski)); + println!("asn : {}", asn); + println!("spki_len : {}", spki.len()); + println!("spki : {}", hex_bytes(spki)); +} + +fn print_aspa(header: &PduHeader, body: &[u8]) { + println!("session_id : {}", header.session_id()); + + if body.len() < ASPA_FIXED_BODY_LEN { + println!("invalid ASPA body length: {}", body.len()); + println!("raw body : {}", hex_bytes(body)); + return; + } + + let flags = body[0]; + let zero1 = body[1]; + let zero2 = body[2]; + let zero3 = body[3]; + let customer_asn = u32::from_be_bytes([body[4], body[5], body[6], body[7]]); + + println!("flags : 0x{:02x} ({})", flags, flag_meaning(flags)); + println!("reserved : [{}, {}, {}]", zero1, zero2, zero3); + println!("customer_as : {}", customer_asn); + + let providers_raw = &body[8..]; + if providers_raw.len() % 4 != 0 { + println!("invalid ASPA providers length: {}", providers_raw.len()); + println!("providers : {}", hex_bytes(providers_raw)); + return; + } + + let mut providers = Vec::new(); + for chunk in providers_raw.chunks_exact(4) { + providers.push(u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); + } + + println!("providers : {:?}", providers); +} + +pub fn parse_serial_notify_serial(body: &[u8]) -> Option { + if body.len() != 4 { + return None; + } + + Some(u32::from_be_bytes([body[0], body[1], body[2], body[3]])) +} + +pub fn parse_end_of_data_serial(body: &[u8]) -> Option { + match body.len() { + 4 | 16 => Some(u32::from_be_bytes([body[0], body[1], body[2], body[3]])), + _ => None, + } +} + +#[derive(Debug, Clone, Copy)] +pub struct EndOfDataInfo { + pub serial: u32, + pub refresh: Option, + pub retry: Option, + pub expire: Option, +} + +pub fn parse_end_of_data_info(body: &[u8]) -> Option { + match body.len() { + 4 => { + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + Some(EndOfDataInfo { + serial, + refresh: None, + retry: None, + expire: None, + }) + } + 16 => { + let serial = u32::from_be_bytes([body[0], body[1], body[2], body[3]]); + let refresh = u32::from_be_bytes([body[4], body[5], body[6], body[7]]); + let retry = u32::from_be_bytes([body[8], body[9], body[10], body[11]]); + let expire = u32::from_be_bytes([body[12], body[13], body[14], body[15]]); + Some(EndOfDataInfo { + serial, + refresh: Some(refresh), + retry: Some(retry), + expire: Some(expire), + }) + } + _ => None, + } +} diff --git a/src/bin/rtr_debug_client/protocol.rs b/src/bin/rtr_debug_client/protocol.rs new file mode 100644 index 0000000..d2e6e37 --- /dev/null +++ b/src/bin/rtr_debug_client/protocol.rs @@ -0,0 +1,155 @@ +use std::fmt; + +pub const HEADER_LEN: usize = 8; +pub const SERIAL_QUERY_LEN: usize = 12; +pub const MAX_PDU_LEN: u32 = 1024 * 1024; // 1 MiB + +pub const IPV4_PREFIX_BODY_LEN: usize = 12; +pub const IPV6_PREFIX_BODY_LEN: usize = 24; +pub const END_OF_DATA_V0_BODY_LEN: usize = 4; +pub const END_OF_DATA_V1_BODY_LEN: usize = 16; +pub const ROUTER_KEY_FIXED_BODY_LEN: usize = 26; +pub const ASPA_FIXED_BODY_LEN: usize = 8; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QueryMode { + Reset, + Serial { session_id: u16, serial: u32 }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PduType { + SerialNotify, + SerialQuery, + ResetQuery, + CacheResponse, + Ipv4Prefix, + Ipv6Prefix, + EndOfData, + CacheReset, + ErrorReport, + Aspa, + Unknown(u8), +} + +impl PduType { + pub fn code(self) -> u8 { + match self { + Self::SerialNotify => 0, + Self::SerialQuery => 1, + Self::ResetQuery => 2, + Self::CacheResponse => 3, + Self::Ipv4Prefix => 4, + Self::Ipv6Prefix => 6, + Self::EndOfData => 7, + Self::CacheReset => 8, + Self::ErrorReport => 10, + Self::Aspa => 11, + Self::Unknown(v) => v, + } + } + + pub fn name(self) -> &'static str { + match self { + Self::SerialNotify => "Serial Notify", + Self::SerialQuery => "Serial Query", + Self::ResetQuery => "Reset Query", + Self::CacheResponse => "Cache Response", + Self::Ipv4Prefix => "IPv4 Prefix", + Self::Ipv6Prefix => "IPv6 Prefix", + Self::EndOfData => "End of Data", + Self::CacheReset => "Cache Reset", + Self::ErrorReport => "Error Report", + Self::Aspa => "ASPA", + Self::Unknown(_) => "Unknown", + } + } +} + +impl From for PduType { + fn from(value: u8) -> Self { + match value { + 0 => Self::SerialNotify, + 1 => Self::SerialQuery, + 2 => Self::ResetQuery, + 3 => Self::CacheResponse, + 4 => Self::Ipv4Prefix, + 6 => Self::Ipv6Prefix, + 7 => Self::EndOfData, + 8 => Self::CacheReset, + 10 => Self::ErrorReport, + 11 => Self::Aspa, + x => Self::Unknown(x), + } + } +} + +impl fmt::Display for PduType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unknown(v) => write!(f, "{} ({})", self.name(), v), + _ => write!(f, "{} ({})", self.name(), self.code()), + } + } +} + +#[derive(Debug, Clone)] +pub struct PduHeader { + pub version: u8, + pub pdu_type_raw: u8, + pub field1: u16, + pub length: u32, +} + +impl PduHeader { + pub fn from_bytes(buf: [u8; HEADER_LEN]) -> Self { + Self { + version: buf[0], + pdu_type_raw: buf[1], + field1: u16::from_be_bytes([buf[2], buf[3]]), + length: u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]), + } + } + + pub fn pdu_type(&self) -> PduType { + self.pdu_type_raw.into() + } + + pub fn session_id(&self) -> u16 { + self.field1 + } + + pub fn error_code(&self) -> u16 { + self.field1 + } +} + +#[derive(Debug, Clone)] +pub struct RawPdu { + pub header: PduHeader, + pub body: Vec, +} + +pub fn flag_meaning(flags: u8) -> &'static str { + if flags & 0x01 == 0x01 { + "announcement" + } else { + "withdrawal" + } +} + +pub fn hex_bytes(data: &[u8]) -> String { + if data.is_empty() { + return "".to_string(); + } + + let mut out = String::with_capacity(data.len() * 3 - 1); + for (idx, b) in data.iter().enumerate() { + if idx > 0 { + out.push(' '); + } + use std::fmt::Write as _; + let _ = write!(out, "{:02x}", b); + } + out +} \ No newline at end of file diff --git a/src/bin/rtr_debug_client/wire.rs b/src/bin/rtr_debug_client/wire.rs new file mode 100644 index 0000000..346cb01 --- /dev/null +++ b/src/bin/rtr_debug_client/wire.rs @@ -0,0 +1,81 @@ +use std::io; + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::protocol::{ + PduHeader, PduType, RawPdu, HEADER_LEN, MAX_PDU_LEN, SERIAL_QUERY_LEN, +}; + +pub async fn send_reset_query(stream: &mut S, version: u8) -> io::Result<()> +where + S: AsyncWrite + Unpin, +{ + let mut buf = [0u8; HEADER_LEN]; + buf[0] = version; + buf[1] = PduType::ResetQuery.code(); + buf[2..4].copy_from_slice(&0u16.to_be_bytes()); + buf[4..8].copy_from_slice(&(HEADER_LEN as u32).to_be_bytes()); + stream.write_all(&buf).await?; + stream.flush().await +} + +pub async fn send_serial_query( + stream: &mut S, + version: u8, + session_id: u16, + serial: u32, +) -> io::Result<()> +where + S: AsyncWrite + Unpin, +{ + let mut buf = [0u8; SERIAL_QUERY_LEN]; + buf[0] = version; + buf[1] = PduType::SerialQuery.code(); + buf[2..4].copy_from_slice(&session_id.to_be_bytes()); + buf[4..8].copy_from_slice(&(SERIAL_QUERY_LEN as u32).to_be_bytes()); + buf[8..12].copy_from_slice(&serial.to_be_bytes()); + stream.write_all(&buf).await?; + stream.flush().await +} + +pub async fn read_header(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + let mut buf = [0u8; HEADER_LEN]; + stream.read_exact(&mut buf).await?; + Ok(PduHeader::from_bytes(buf)) +} + +pub async fn read_pdu(stream: &mut S) -> io::Result +where + S: AsyncRead + Unpin, +{ + let header = read_header(stream).await?; + + if header.length < HEADER_LEN as u32 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "invalid PDU length {} < {}", + header.length, HEADER_LEN + ), + )); + } + + if header.length > MAX_PDU_LEN { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "PDU length {} exceeds max allowed {}", + header.length, MAX_PDU_LEN + ), + )); + } + + let body_len = header.length as usize - HEADER_LEN; + let mut body = vec![0u8; body_len]; + stream.read_exact(&mut body).await?; + + Ok(RawPdu { header, body }) +} \ No newline at end of file diff --git a/src/data_model/resources/as_resources.rs b/src/data_model/resources/as_resources.rs index fc49f4d..6ac7ef9 100644 --- a/src/data_model/resources/as_resources.rs +++ b/src/data_model/resources/as_resources.rs @@ -1,3 +1,4 @@ +use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq)] pub struct ASIdentifiers { @@ -54,7 +55,7 @@ impl ASRange { } } -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default, Serialize, Deserialize)] pub struct Asn(u32); impl Asn { diff --git a/src/data_model/resources/ip_resources.rs b/src/data_model/resources/ip_resources.rs index efc0250..a419aaa 100644 --- a/src/data_model/resources/ip_resources.rs +++ b/src/data_model/resources/ip_resources.rs @@ -1,71 +1,157 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; +use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct IPAddrBlocks { - pub ips: Vec + pub ips: Vec, } - // IP Address Family -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct IPAddressFamily { pub address_family: Afi, pub ip_address_choice: IPAddressChoice, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)] pub enum Afi { Ipv4, Ipv6, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum IPAddressChoice { Inherit, AddressOrRange(Vec), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum IPAddressOrRange { AddressPrefix(IPAddressPrefix), AddressRange(IPAddressRange), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)] pub struct IPAddressPrefix { pub address: IPAddress, pub prefix_length: u8, } -#[derive(Debug, Clone, PartialEq, Eq)] +impl IPAddressPrefix { + pub fn new(address: IPAddress, prefix_length: u8) -> Self { + Self { + address, + prefix_length, + } + } + + pub fn is_ipv4(&self) -> bool { + self.address.is_ipv4() + } + + pub fn is_ipv6(&self) -> bool { + self.address.is_ipv6() + } + + pub fn afi(&self) -> Afi { + self.address.afi() + } + + pub fn address(&self) -> IPAddress { + self.address + } + + pub fn prefix_length(&self) -> u8 { + self.prefix_length + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)] pub struct IPAddressRange { pub min: IPAddress, pub max: IPAddress, } -use std::net::{Ipv4Addr, Ipv6Addr}; - -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct IPAddress(u128); - -impl IPAddress { - pub fn to_ipv4(self) -> Option { - if self.0 <= u32::MAX as u128 { - Some(Ipv4Addr::from(self.0 as u32)) - } else { - None - } +impl IPAddressRange { + pub fn new(min: IPAddress, max: IPAddress) -> Self { + Self { min, max } } - pub fn to_ipv6(self) -> Ipv6Addr { - Ipv6Addr::from(self.0) + pub fn min(&self) -> IPAddress { + self.min } - pub fn is_ipv4(self) -> bool { - self.0 <= u32::MAX as u128 - } - - pub fn as_u128(self) -> u128 { - self.0 + pub fn max(&self) -> IPAddress { + self.max } } +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize, Deserialize)] +pub enum IPAddress { + V4(Ipv4Addr), + V6(Ipv6Addr), +} + +impl IPAddress { + pub fn from_ipv4(addr: Ipv4Addr) -> Self { + Self::V4(addr) + } + + pub fn from_ipv6(addr: Ipv6Addr) -> Self { + Self::V6(addr) + } + + pub fn to_ipv4(self) -> Option { + match self { + Self::V4(addr) => Some(addr), + Self::V6(_) => None, + } + } + + pub fn to_ipv6(self) -> Option { + match self { + Self::V4(_) => None, + Self::V6(addr) => Some(addr), + } + } + + pub fn is_ipv4(self) -> bool { + matches!(self, Self::V4(_)) + } + + pub fn is_ipv6(self) -> bool { + matches!(self, Self::V6(_)) + } + + pub fn afi(self) -> Afi { + match self { + Self::V4(_) => Afi::Ipv4, + Self::V6(_) => Afi::Ipv6, + } + } + + /// Returns the numeric address value. + /// + /// For IPv4, this is the 32-bit address widened to u128. + /// For IPv6, this is the full 128-bit address. + pub fn as_u128(self) -> u128 { + match self { + Self::V4(addr) => u32::from(addr) as u128, + Self::V6(addr) => u128::from(addr), + } + } + + pub fn as_v4_u32(self) -> Option { + match self { + Self::V4(addr) => Some(u32::from(addr)), + Self::V6(_) => None, + } + } + + pub fn as_v6_u128(self) -> Option { + match self { + Self::V4(_) => None, + Self::V6(addr) => Some(u128::from(addr)), + } + } +} \ No newline at end of file diff --git a/src/data_model/resources/mod.rs b/src/data_model/resources/mod.rs index 46a9e49..5b8b049 100644 --- a/src/data_model/resources/mod.rs +++ b/src/data_model/resources/mod.rs @@ -1,3 +1,3 @@ -pub(crate) mod ip_resources; -pub(crate) mod as_resources; +pub mod ip_resources; +pub mod as_resources; pub mod resource; diff --git a/src/lib.rs b/src/lib.rs index 72d8e7e..5ef2acb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,3 @@ pub mod data_model; mod slurm; -mod rtr; +pub mod rtr; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..1eebd79 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,218 @@ +use std::net::SocketAddr; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +use anyhow::{anyhow, Result}; +use tokio::task::JoinHandle; +use tracing::{info, warn}; + +use rpki::rtr::cache::{RtrCache, SharedRtrCache}; +use rpki::rtr::loader::load_vrps_from_file; +use rpki::rtr::payload::Timing; +use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService}; +use rpki::rtr::store_db::RtrStore; + +#[derive(Debug, Clone)] +struct AppConfig { + enable_tls: bool, + tcp_addr: SocketAddr, + tls_addr: SocketAddr, + + db_path: String, + vrp_file: String, + tls_cert_path: String, + tls_key_path: String, + + max_delta: u8, + refresh_interval: Duration, + + service_config: RtrServiceConfig, +} + +impl Default for AppConfig { + fn default() -> Self { + Self { + enable_tls: false, + tcp_addr: "0.0.0.0:3323".parse().expect("invalid default tcp_addr"), + tls_addr: "0.0.0.0:3324".parse().expect("invalid default tls_addr"), + + db_path: "./rtr-db".to_string(), + vrp_file: r"C:\Users\xuxiu\git_code\rpki\data\vrps.txt".to_string(), + tls_cert_path: "./certs/server.crt".to_string(), + tls_key_path: "./certs/server.key".to_string(), + + max_delta: 100, + refresh_interval: Duration::from_secs(300), + + service_config: RtrServiceConfig { + max_connections: 512, + notify_queue_size: 1024, + }, + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + init_tracing(); + + let config = AppConfig::default(); + log_startup_config(&config); + + let store = open_store(&config)?; + let shared_cache = init_shared_cache(&config, &store)?; + + let service = RtrService::with_config(shared_cache.clone(), config.service_config.clone()); + let notifier = service.notifier(); + + let running = start_servers(&config, &service); + let refresh_task = spawn_refresh_task(&config, shared_cache.clone(), store.clone(), notifier); + + wait_for_shutdown().await?; + + running.shutdown(); + running.wait().await; + + refresh_task.abort(); + let _ = refresh_task.await; + + info!("RTR service stopped"); + Ok(()) +} + +fn open_store(config: &AppConfig) -> Result { + info!("opening RTR store: {}", config.db_path); + RtrStore::open(&config.db_path) +} + +fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result { + let initial_cache = RtrCache::default().init( + store, + config.max_delta, + Timing::default(), + || load_vrps_from_file(&config.vrp_file), + )?; + + let shared_cache: SharedRtrCache = Arc::new(RwLock::new(initial_cache)); + + { + let cache = shared_cache + .read() + .map_err(|_| anyhow!("cache read lock poisoned during startup"))?; + + info!( + "cache initialized: session_id={}, serial={}", + cache.session_id(), + cache.serial() + ); + } + + Ok(shared_cache) +} + +fn start_servers(config: &AppConfig, service: &RtrService) -> RunningRtrService { + if config.enable_tls { + info!("starting TCP and TLS RTR servers"); + service.spawn_tcp_and_tls_from_pem( + config.tcp_addr, + config.tls_addr, + &config.tls_cert_path, + &config.tls_key_path, + ) + } else { + info!("starting TCP RTR server"); + service.spawn_tcp_only(config.tcp_addr) + } +} + +fn spawn_refresh_task( + config: &AppConfig, + shared_cache: SharedRtrCache, + store: RtrStore, + notifier: RtrNotifier, +) -> JoinHandle<()> { + let refresh_interval = config.refresh_interval; + let vrp_file = config.vrp_file.clone(); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(refresh_interval); + + loop { + interval.tick().await; + + match load_vrps_from_file(&vrp_file) { + Ok(payloads) => { + let updated = { + let mut cache = match shared_cache.write() { + Ok(guard) => guard, + Err(_) => { + warn!("cache write lock poisoned during refresh"); + continue; + } + }; + + let old_serial = cache.serial(); + + match cache.update(payloads, &store) { + Ok(()) => cache.serial() != old_serial, + Err(err) => { + warn!("RTR cache update failed: {:?}", err); + false + } + } + }; + + if updated { + notifier.notify_cache_updated(); + info!("RTR cache updated, serial notify broadcast sent"); + } + } + Err(err) => { + warn!("failed to reload VRPs from file {}: {:?}", vrp_file, err); + } + } + } + }) +} + +async fn wait_for_shutdown() -> Result<()> { + tokio::signal::ctrl_c().await?; + info!("shutdown signal received"); + Ok(()) +} + +fn log_startup_config(config: &AppConfig) { + info!("starting RTR service"); + info!("db_path={}", config.db_path); + info!("tcp_addr={}", config.tcp_addr); + info!("tls_enabled={}", config.enable_tls); + + if config.enable_tls { + info!("tls_addr={}", config.tls_addr); + info!("tls_cert_path={}", config.tls_cert_path); + info!("tls_key_path={}", config.tls_key_path); + } + + info!("vrp_file={}", config.vrp_file); + info!("max_delta={}", config.max_delta); + info!( + "refresh_interval_secs={}", + config.refresh_interval.as_secs() + ); + info!( + "max_connections={}", + config.service_config.max_connections + ); + info!( + "notify_queue_size={}", + config.service_config.notify_queue_size + ); +} + +fn init_tracing() { + let _ = tracing_subscriber::fmt() + .with_target(true) + .with_thread_ids(true) + .with_level(true) + .try_init(); +} \ No newline at end of file diff --git a/src/rtr/cache.rs b/src/rtr/cache.rs index 29619b7..7d2b115 100644 --- a/src/rtr/cache.rs +++ b/src/rtr/cache.rs @@ -1,16 +1,22 @@ +use std::cmp::Ordering; use std::collections::{BTreeSet, VecDeque}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::time::{Duration, Instant}; use chrono::{DateTime, NaiveDateTime, Utc}; use serde::{Deserialize, Serialize}; - -use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; +use sha2::{Digest, Sha256}; +use crate::data_model::resources::ip_resources::IPAddress; +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Timing}; use crate::rtr::store_db::RtrStore; const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(600); const DEFAULT_EXPIRE_INTERVAL: Duration = Duration::from_secs(7200); + + +pub type SharedRtrCache = Arc>; + #[derive(Debug, Clone)] pub struct DualTime { instant: Instant, @@ -79,6 +85,11 @@ pub struct Snapshot { router_keys: BTreeSet, aspas: BTreeSet, created_at: DualTime, + + origins_hash: [u8; 32], + router_keys_hash: [u8; 32], + aspas_hash: [u8; 32], + snapshot_hash: [u8; 32], } impl Snapshot { @@ -87,12 +98,18 @@ impl Snapshot { router_keys: BTreeSet, aspas: BTreeSet, ) -> Self { - Snapshot { + let mut snapshot = Snapshot { origins, router_keys, aspas, created_at: DualTime::now(), - } + origins_hash: [0u8; 32], + router_keys_hash: [0u8; 32], + aspas_hash: [0u8; 32], + snapshot_hash: [0u8; 32], + }; + snapshot.recompute_hashes(); + snapshot } pub fn empty() -> Self { @@ -118,37 +135,85 @@ impl Snapshot { } } - Snapshot { - origins, - router_keys, - aspas, - created_at: DualTime::now(), + Snapshot::new(origins, router_keys, aspas) + } + + pub fn recompute_hashes(&mut self) { + self.origins_hash = self.compute_origins_hash(); + self.router_keys_hash = self.compute_router_keys_hash(); + self.aspas_hash = self.compute_aspas_hash(); + self.snapshot_hash = self.compute_snapshot_hash(); + } + + fn compute_origins_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.origins.iter()) + } + + fn compute_router_keys_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.router_keys.iter()) + } + + fn compute_aspas_hash(&self) -> [u8; 32] { + Self::hash_ordered_iter(self.aspas.iter()) + } + + fn compute_snapshot_hash(&self) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(b"snapshot:v1"); + hasher.update(self.origins_hash); + hasher.update(self.router_keys_hash); + hasher.update(self.aspas_hash); + hasher.finalize().into() + } + + fn hash_ordered_iter<'a, T, I>(iter: I) -> [u8; 32] + where + T: Serialize + 'a, + I: IntoIterator, + { + let mut hasher = Sha256::new(); + hasher.update(b"set:v1"); + + for item in iter { + let encoded = + serde_json::to_vec(item).expect("serialize snapshot item for hashing failed"); + let len = (encoded.len() as u32).to_be_bytes(); + hasher.update(len); + hasher.update(encoded); } + + hasher.finalize().into() } pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec, Vec) { let mut announced = Vec::new(); let mut withdrawn = Vec::new(); - for origin in new_snapshot.origins.difference(&self.origins) { - announced.push(Payload::RouteOrigin(origin.clone())); - } - for origin in self.origins.difference(&new_snapshot.origins) { - withdrawn.push(Payload::RouteOrigin(origin.clone())); + if !self.same_origins(new_snapshot) { + for origin in new_snapshot.origins.difference(&self.origins) { + announced.push(Payload::RouteOrigin(origin.clone())); + } + for origin in self.origins.difference(&new_snapshot.origins) { + withdrawn.push(Payload::RouteOrigin(origin.clone())); + } } - for key in new_snapshot.router_keys.difference(&self.router_keys) { - announced.push(Payload::RouterKey(key.clone())); - } - for key in self.router_keys.difference(&new_snapshot.router_keys) { - withdrawn.push(Payload::RouterKey(key.clone())); + if !self.same_router_keys(new_snapshot) { + for key in new_snapshot.router_keys.difference(&self.router_keys) { + announced.push(Payload::RouterKey(key.clone())); + } + for key in self.router_keys.difference(&new_snapshot.router_keys) { + withdrawn.push(Payload::RouterKey(key.clone())); + } } - for aspa in new_snapshot.aspas.difference(&self.aspas) { - announced.push(Payload::Aspa(aspa.clone())); - } - for aspa in self.aspas.difference(&new_snapshot.aspas) { - withdrawn.push(Payload::Aspa(aspa.clone())); + if !self.same_aspas(new_snapshot) { + for aspa in new_snapshot.aspas.difference(&self.aspas) { + announced.push(Payload::Aspa(aspa.clone())); + } + for aspa in self.aspas.difference(&new_snapshot.aspas) { + withdrawn.push(Payload::Aspa(aspa.clone())); + } } (announced, withdrawn) @@ -169,6 +234,66 @@ impl Snapshot { v } + + /// Payloads sorted for RTR full snapshot sending. + /// Snapshot represents current valid state, so all payloads are treated as announcements. + pub fn payloads_for_rtr(&self) -> Vec { + let mut payloads = self.payloads(); + sort_payloads_for_rtr(&mut payloads, true); + payloads + } + + pub fn origins_hash(&self) -> [u8; 32] { + self.origins_hash + } + + pub fn router_keys_hash(&self) -> [u8; 32] { + self.router_keys_hash + } + + pub fn aspas_hash(&self) -> [u8; 32] { + self.aspas_hash + } + + pub fn snapshot_hash(&self) -> [u8; 32] { + self.snapshot_hash + } + + pub fn same_origins(&self, other: &Self) -> bool { + self.origins_hash == other.origins_hash + } + + pub fn same_router_keys(&self, other: &Self) -> bool { + self.router_keys_hash == other.router_keys_hash + } + + pub fn same_aspas(&self, other: &Self) -> bool { + self.aspas_hash == other.aspas_hash + } + + pub fn same_content(&self, other: &Self) -> bool { + self.snapshot_hash == other.snapshot_hash + } + + pub fn origins(&self) -> &BTreeSet { + &self.origins + } + + pub fn router_keys(&self) -> &BTreeSet { + &self.router_keys + } + + pub fn aspas(&self) -> &BTreeSet { + &self.aspas + } +} + +impl Snapshot { + pub fn is_empty(&self) -> bool { + self.origins.is_empty() + && self.router_keys.is_empty() + && self.aspas.is_empty() + } } #[derive(Debug, Serialize, Deserialize)] @@ -180,7 +305,10 @@ pub struct Delta { } impl Delta { - pub fn new(serial: u32, announced: Vec, withdrawn: Vec) -> Self { + pub fn new(serial: u32, mut announced: Vec, mut withdrawn: Vec) -> Self { + sort_payloads_for_rtr(&mut announced, true); + sort_payloads_for_rtr(&mut withdrawn, false); + Delta { serial, announced, @@ -201,8 +329,8 @@ impl Delta { &self.withdrawn } - pub fn created_at(self) -> DualTime { - self.created_at + pub fn created_at(&self) -> DualTime { + self.created_at.clone() } } @@ -219,7 +347,7 @@ pub struct RtrCache { // Max number of deltas to keep. max_delta: u8, // Refresh interval. - refresh_interval: Duration, + timing: Timing, // Last update begin time. last_update_begin: DualTime, // Last update end time. @@ -237,7 +365,7 @@ impl Default for RtrCache { snapshot: Snapshot::empty(), deltas: VecDeque::with_capacity(100), max_delta: 100, - refresh_interval: Duration::from_secs(600), + timing: Timing::default(), last_update_begin: now.clone(), last_update_end: now.clone(), created_at: now, @@ -248,9 +376,10 @@ impl Default for RtrCache { pub struct RtrCacheBuilder { session_id: Option, max_delta: Option, - refresh_interval: Option, + timing: Option, serial: Option, snapshot: Option, + deltas: Option>>, created_at: Option, } @@ -259,9 +388,10 @@ impl RtrCacheBuilder { Self { session_id: None, max_delta: None, - refresh_interval: None, + timing: None, serial: None, snapshot: None, + deltas: None, created_at: None, } } @@ -276,8 +406,8 @@ impl RtrCacheBuilder { self } - pub fn refresh_interval(mut self, v: Duration) -> Self { - self.refresh_interval = Some(v); + pub fn timing(mut self, v: Timing) -> Self { + self.timing = Some(v); self } @@ -291,6 +421,11 @@ impl RtrCacheBuilder { self } + pub fn deltas(mut self, v: VecDeque>) -> Self { + self.deltas = Some(v); + self + } + pub fn created_at(mut self, v: DualTime) -> Self { self.created_at = Some(v); self @@ -299,8 +434,12 @@ impl RtrCacheBuilder { pub fn build(self) -> RtrCache { let now = DualTime::now(); let max_delta = self.max_delta.unwrap_or(100); - let refresh_interval = self.refresh_interval.unwrap_or(Duration::from_secs(600)); + let timing = self.timing.unwrap_or_default(); let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty); + let deltas = self + .deltas + .unwrap_or_else(|| VecDeque::with_capacity(max_delta.into())); + let serial = self.serial.unwrap_or(0); let created_at = self.created_at.unwrap_or_else(|| now.clone()); let session_id = self.session_id.unwrap_or_else(rand::random); @@ -309,9 +448,9 @@ impl RtrCacheBuilder { session_id, serial, snapshot, - deltas: VecDeque::with_capacity(max_delta.into()), + deltas, max_delta, - refresh_interval, + timing, last_update_begin: now.clone(), last_update_end: now, created_at, @@ -325,42 +464,44 @@ impl RtrCache { self, store: &RtrStore, max_delta: u8, - refresh_interval: Duration, + timing: Timing, file_loader: impl Fn() -> anyhow::Result>, ) -> anyhow::Result { - let snapshot = store.get_snapshot()?; - let session_id = store.get_session_id()?; - let serial = store.get_serial()?; - - if let (Some(snapshot), Some(session_id), Some(serial)) = - (snapshot, session_id, serial) - { - let mut cache = RtrCacheBuilder::new() - .session_id(session_id) - .max_delta(max_delta) - .refresh_interval(refresh_interval) - .serial(serial) - .snapshot(snapshot) - .build(); - - if let Some((min_serial, _max_serial)) = store.get_delta_window()? { - let deltas = store.load_deltas_since(min_serial.wrapping_sub(1))?; - for delta in deltas { - cache.push_delta(Arc::new(delta)); - } - } - + if let Some(cache) = Self::try_restore_from_store(store, max_delta, timing)? { + tracing::info!( + "RTR cache restored from store: session_id={}, serial={}", + self.session_id, + self.serial + ); return Ok(cache); } + tracing::warn!("RTR cache store unavailable or invalid, fallback to file loader"); + let payloads = file_loader()?; let snapshot = Snapshot::from_payloads(payloads); + + if snapshot.is_empty() { + anyhow::bail!("file loader returned an empty snapshot"); + } + + tracing::info!( + "RTR cache initialized from file loader: session_id={}, serial={}", + self.session_id, + self.serial + ); + let serial = 1; let session_id: u16 = rand::random(); + let snapshot_for_store = snapshot.clone(); + let snapshot_for_cache = snapshot.clone(); let store = store.clone(); + tokio::spawn(async move { - if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) { + if let Err(e) = + store.save_snapshot_and_meta(&snapshot_for_store, session_id, serial) + { tracing::error!("persist failed: {:?}", e); } }); @@ -368,10 +509,65 @@ impl RtrCache { Ok(RtrCacheBuilder::new() .session_id(session_id) .max_delta(max_delta) - .refresh_interval(refresh_interval) + .timing(timing) + .serial(serial) + .snapshot(snapshot_for_cache) + .build()) + } + + fn try_restore_from_store( + store: &RtrStore, + max_delta: u8, + timing: Timing, + ) -> anyhow::Result> { + let snapshot = store.get_snapshot()?; + let session_id = store.get_session_id()?; + let serial = store.get_serial()?; + + let (snapshot, session_id, serial) = match (snapshot, session_id, serial) { + (Some(snapshot), Some(session_id), Some(serial)) => (snapshot, session_id, serial), + _ => { + tracing::warn!("RTR cache store incomplete: snapshot/session_id/serial missing"); + return Ok(None); + } + }; + + if snapshot.is_empty() { + tracing::warn!("RTR cache store snapshot is empty, treat as unusable"); + return Ok(None); + } + + let mut cache = RtrCacheBuilder::new() + .session_id(session_id) + .max_delta(max_delta) + .timing(timing) .serial(serial) .snapshot(snapshot) - .build()) + .build(); + + match store.get_delta_window()? { + Some((min_serial, _max_serial)) => { + let deltas = match store.load_deltas_since(min_serial.wrapping_sub(1)) { + Ok(deltas) => deltas, + Err(err) => { + tracing::warn!( + "RTR cache store delta recovery failed, treat store as unusable: {:?}", + err + ); + return Ok(None); + } + }; + + for delta in deltas { + cache.push_delta(Arc::new(delta)); + } + } + None => { + tracing::info!("RTR cache store has no delta window, restore snapshot only"); + } + } + + Ok(Some(cache)) } fn next_serial(&mut self) -> u32 { @@ -431,10 +627,19 @@ impl RtrCache { new_payloads: Vec, store: &RtrStore, ) -> anyhow::Result<()> { + self.last_update_begin = DualTime::now(); + let new_snapshot = Snapshot::from_payloads(new_payloads); + + if self.snapshot.same_content(&new_snapshot) { + self.last_update_end = DualTime::now(); + return Ok(()); + } + let (announced, withdrawn) = self.snapshot.diff(&new_snapshot); if announced.is_empty() && withdrawn.is_empty() { + self.last_update_end = DualTime::now(); return Ok(()); } @@ -462,8 +667,8 @@ impl RtrCache { self.serial } - pub fn refresh_interval(&self) -> Duration { - self.refresh_interval + pub fn timing(&self) -> Timing { + self.timing } pub fn retry_interval(&self) -> Duration { @@ -477,6 +682,18 @@ impl RtrCache { pub fn current_snapshot(&self) -> (&Snapshot, u32, u16) { (&self.snapshot, self.serial, self.session_id) } + + pub fn last_update_begin(&self) -> DualTime { + self.last_update_begin.clone() + } + + pub fn last_update_end(&self) -> DualTime { + self.last_update_end.clone() + } + + pub fn created_at(&self) -> DualTime { + self.created_at.clone() + } } impl RtrCache { @@ -524,6 +741,7 @@ impl RtrCache { return SerialResult::UpToDate; } + let _ = newest_serial; SerialResult::Deltas(result) } } @@ -536,3 +754,176 @@ pub enum SerialResult { /// Delta window cannot cover; reset required. ResetRequired, } + +//------------ RTR ordering ------------------------------------------------- + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] +enum PayloadPduType { + Ipv4Prefix = 4, + Ipv6Prefix = 6, + RouterKey = 9, + Aspa = 11, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum RouteOriginKey { + V4 { + addr: u32, + plen: u8, + mlen: u8, + asn: u32, + }, + V6 { + addr: u128, + plen: u8, + mlen: u8, + asn: u32, + }, +} + +fn sort_payloads_for_rtr(payloads: &mut [Payload], announce: bool) { + payloads.sort_by(|a, b| compare_payload_for_rtr(a, b, announce)); +} + +fn compare_payload_for_rtr(a: &Payload, b: &Payload, announce: bool) -> Ordering { + let type_a = payload_pdu_type(a); + let type_b = payload_pdu_type(b); + + match type_a.cmp(&type_b) { + Ordering::Equal => {} + other => return other, + } + + match (a, b) { + (Payload::RouteOrigin(a), Payload::RouteOrigin(b)) => { + compare_route_origin_for_rtr(a, b, announce) + } + (Payload::RouterKey(a), Payload::RouterKey(b)) => { + compare_router_key_for_rtr(a, b) + } + (Payload::Aspa(a), Payload::Aspa(b)) => compare_aspa_for_rtr(a, b), + _ => Ordering::Equal, + } +} + +fn payload_pdu_type(payload: &Payload) -> PayloadPduType { + match payload { + Payload::RouteOrigin(ro) => { + if route_origin_is_ipv4(ro) { + PayloadPduType::Ipv4Prefix + } else { + PayloadPduType::Ipv6Prefix + } + } + Payload::RouterKey(_) => PayloadPduType::RouterKey, + Payload::Aspa(_) => PayloadPduType::Aspa, + } +} + +fn route_origin_is_ipv4(ro: &RouteOrigin) -> bool { + ro.prefix().address.is_ipv4() +} + +fn route_origin_key(ro: &RouteOrigin) -> RouteOriginKey { + let prefix = ro.prefix(); + let plen = prefix.prefix_length; + let mlen = ro.max_length(); + let asn = ro.asn().into_u32(); + + match prefix.address { + IPAddress::V4(addr) => { + RouteOriginKey::V4 { + addr: u32::from(addr), + plen, + mlen, + asn, + } + } + IPAddress::V6(addr) => { + RouteOriginKey::V6 { + addr: u128::from(addr), + plen, + mlen, + asn, + } + } + } +} + +fn compare_route_origin_for_rtr( + a: &RouteOrigin, + b: &RouteOrigin, + announce: bool, +) -> Ordering { + match (route_origin_key(a), route_origin_key(b)) { + ( + RouteOriginKey::V4 { + addr: addr_a, + plen: plen_a, + mlen: mlen_a, + asn: asn_a, + }, + RouteOriginKey::V4 { + addr: addr_b, + plen: plen_b, + mlen: mlen_b, + asn: asn_b, + }, + ) => { + if announce { + addr_b.cmp(&addr_a) + .then_with(|| mlen_b.cmp(&mlen_a)) + .then_with(|| plen_b.cmp(&plen_a)) + .then_with(|| asn_b.cmp(&asn_a)) + } else { + addr_a.cmp(&addr_b) + .then_with(|| mlen_a.cmp(&mlen_b)) + .then_with(|| plen_a.cmp(&plen_b)) + .then_with(|| asn_a.cmp(&asn_b)) + } + } + + ( + RouteOriginKey::V6 { + addr: addr_a, + plen: plen_a, + mlen: mlen_a, + asn: asn_a, + }, + RouteOriginKey::V6 { + addr: addr_b, + plen: plen_b, + mlen: mlen_b, + asn: asn_b, + }, + ) => { + if announce { + addr_b.cmp(&addr_a) + .then_with(|| mlen_b.cmp(&mlen_a)) + .then_with(|| plen_b.cmp(&plen_a)) + .then_with(|| asn_b.cmp(&asn_a)) + } else { + addr_a.cmp(&addr_b) + .then_with(|| mlen_a.cmp(&mlen_b)) + .then_with(|| plen_a.cmp(&plen_b)) + .then_with(|| asn_a.cmp(&asn_b)) + } + } + + _ => Ordering::Equal, + } +} + +fn compare_router_key_for_rtr(a: &RouterKey, b: &RouterKey) -> Ordering { + a.ski() + .cmp(&b.ski()) + .then_with(|| a.spki().len().cmp(&b.spki().len())) + .then_with(|| a.spki().cmp(b.spki())) + .then_with(|| a.asn().into_u32().cmp(&b.asn().into_u32())) +} + +fn compare_aspa_for_rtr(a: &Aspa, b: &Aspa) -> Ordering { + a.customer_asn() + .into_u32() + .cmp(&b.customer_asn().into_u32()) +} \ No newline at end of file diff --git a/src/rtr/loader.rs b/src/rtr/loader.rs new file mode 100644 index 0000000..35633b4 --- /dev/null +++ b/src/rtr/loader.rs @@ -0,0 +1,185 @@ +use std::fs; +use std::net::IpAddr; +use std::path::Path; +use std::str::FromStr; + +use anyhow::{anyhow, Context, Result}; + +use crate::data_model::resources::as_resources::Asn; +use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use crate::rtr::payload::{Payload, RouteOrigin}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ParsedVrp { + pub prefix_addr: IpAddr, + pub prefix_len: u8, + pub max_len: u8, + pub asn: u32, +} + +/// 从文本文件中加载 VRP,并转换成 RTR Payload::RouteOrigin。 +/// +/// 文件格式: +/// +/// ```text +/// # prefix,max_len,asn +/// 10.0.0.0/24,24,65001 +/// 10.0.1.0/24,24,65002 +/// 2001:db8::/32,48,65003 +/// ``` +pub fn load_vrps_from_file(path: impl AsRef) -> Result> { + let path = path.as_ref(); + + let content = fs::read_to_string(path) + .with_context(|| format!("failed to read VRP file: {}", path.display()))?; + + let mut payloads = Vec::new(); + + for (idx, raw_line) in content.lines().enumerate() { + let line_no = idx + 1; + let line = raw_line.trim(); + + if line.is_empty() || line.starts_with('#') { + continue; + } + + let vrp = parse_vrp_line(line) + .with_context(|| format!("invalid VRP line {}: {}", line_no, raw_line))?; + + payloads.push(Payload::RouteOrigin(build_route_origin(vrp)?)); + } + + Ok(payloads) +} + +/// 解析单行 VRP。 +/// +/// 格式: +/// `prefix/prefix_len,max_len,asn` +/// +/// 例如: +/// `10.0.0.0/24,24,65001` +pub fn parse_vrp_line(line: &str) -> Result { + let parts: Vec<_> = line.split(',').map(|s| s.trim()).collect(); + if parts.len() != 3 { + return Err(anyhow!( + "expected format: /,," + )); + } + + let prefix_part = parts[0]; + let max_len = u8::from_str(parts[1]) + .with_context(|| format!("invalid max_len: {}", parts[1]))?; + let asn = u32::from_str(parts[2]) + .with_context(|| format!("invalid asn: {}", parts[2]))?; + + let (addr_str, prefix_len_str) = prefix_part + .split_once('/') + .ok_or_else(|| anyhow!("prefix must be in CIDR form, e.g. 10.0.0.0/24"))?; + + let prefix_addr = IpAddr::from_str(addr_str.trim()) + .with_context(|| format!("invalid IP address: {}", addr_str))?; + + let prefix_len = u8::from_str(prefix_len_str.trim()) + .with_context(|| format!("invalid prefix length: {}", prefix_len_str))?; + + validate_vrp(prefix_addr, prefix_len, max_len)?; + + Ok(ParsedVrp { + prefix_addr, + prefix_len, + max_len, + asn, + }) +} + +fn validate_vrp(prefix_addr: IpAddr, prefix_len: u8, max_len: u8) -> Result<()> { + match prefix_addr { + IpAddr::V4(_) => { + if prefix_len > 32 { + return Err(anyhow!("IPv4 prefix length must be <= 32")); + } + if max_len > 32 { + return Err(anyhow!("IPv4 max_len must be <= 32")); + } + if max_len < prefix_len { + return Err(anyhow!("IPv4 max_len must be >= prefix length")); + } + } + IpAddr::V6(_) => { + if prefix_len > 128 { + return Err(anyhow!("IPv6 prefix length must be <= 128")); + } + if max_len > 128 { + return Err(anyhow!("IPv6 max_len must be <= 128")); + } + if max_len < prefix_len { + return Err(anyhow!("IPv6 max_len must be >= prefix length")); + } + } + } + Ok(()) +} + +pub fn build_route_origin(vrp: ParsedVrp) -> Result { + let address = match vrp.prefix_addr { + IpAddr::V4(addr) => IPAddress::from_ipv4(addr), + IpAddr::V6(addr) => IPAddress::from_ipv6(addr), + }; + + let prefix = IPAddressPrefix::new(address, vrp.prefix_len); + let asn = Asn::from(vrp.asn); + + Ok(RouteOrigin::new(prefix, vrp.max_len, asn)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_ipv4_vrp_line() { + let got = parse_vrp_line("10.0.0.0/24,24,65001").unwrap(); + assert_eq!( + got, + ParsedVrp { + prefix_addr: IpAddr::from_str("10.0.0.0").unwrap(), + prefix_len: 24, + max_len: 24, + asn: 65001, + } + ); + } + + #[test] + fn parse_ipv6_vrp_line() { + let got = parse_vrp_line("2001:db8::/32,48,65003").unwrap(); + assert_eq!( + got, + ParsedVrp { + prefix_addr: IpAddr::from_str("2001:db8::").unwrap(), + prefix_len: 32, + max_len: 48, + asn: 65003, + } + ); + } + + #[test] + fn parse_rejects_invalid_max_len() { + let err = parse_vrp_line("10.0.0.0/24,16,65001").unwrap_err(); + assert!(err.to_string().contains("max_len")); + } + + #[test] + fn parse_rejects_invalid_ip() { + let err = parse_vrp_line("10.0.0.999/24,24,65001").unwrap_err(); + assert!(err.to_string().contains("invalid IP")); + } + + #[test] + fn parse_rejects_invalid_format() { + let err = parse_vrp_line("10.0.0.0/24,24").unwrap_err(); + assert!(err.to_string().contains("expected format")); + } +} \ No newline at end of file diff --git a/src/rtr/mod.rs b/src/rtr/mod.rs index 7062221..df72a99 100644 --- a/src/rtr/mod.rs +++ b/src/rtr/mod.rs @@ -1,7 +1,9 @@ pub mod pdu; pub mod cache; pub mod payload; -mod store_db; -mod session; -mod error_type; -mod state; \ No newline at end of file +pub mod store_db; +pub mod session; +pub mod error_type; +pub mod state; +pub mod server; +pub mod loader; \ No newline at end of file diff --git a/src/rtr/payload.rs b/src/rtr/payload.rs index c80b271..42f24cc 100644 --- a/src/rtr/payload.rs +++ b/src/rtr/payload.rs @@ -1,13 +1,19 @@ use std::fmt::Debug; -use std::sync::Arc; use std::time::Duration; -use asn1_rs::nom::character::streaming::u64; use serde::{Deserialize, Serialize}; use crate::data_model::resources::as_resources::Asn; use crate::data_model::resources::ip_resources::IPAddressPrefix; -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] +enum PayloadPduType { + Ipv4Prefix = 4, + Ipv6Prefix = 6, + RouterKey = 9, + Aspa = 11, +} + +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] pub struct Ski([u8; 20]); @@ -19,6 +25,14 @@ pub struct RouteOrigin { } impl RouteOrigin { + pub fn new(prefix: IPAddressPrefix, max_length: u8, asn: Asn) -> Self { + Self { + prefix, + max_length, + asn, + } + } + pub fn prefix(&self) -> &IPAddressPrefix { &self.prefix } @@ -37,13 +51,29 @@ impl RouteOrigin { pub struct RouterKey { subject_key_identifier: Ski, asn: Asn, - subject_public_key_info: Arc<[u8]>, + subject_public_key_info: Vec, } impl RouterKey { + pub fn new(subject_key_identifier: Ski, asn: Asn, subject_public_key_info: Vec) -> Self { + Self { + subject_key_identifier, + asn, + subject_public_key_info, + } + } + + pub fn ski(&self) -> Ski { + self.subject_key_identifier + } + pub fn asn(&self) -> Asn { self.asn } + + pub fn spki(&self) -> &[u8] { + &self.subject_public_key_info + } } #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] @@ -53,6 +83,16 @@ pub struct Aspa { } impl Aspa { + pub fn new(customer_asn: Asn, mut provider_asns: Vec) -> Self { + provider_asns.sort(); + provider_asns.dedup(); + + Self { + customer_asn, + provider_asns, + } + } + pub fn customer_asn(&self) -> Asn { self.customer_asn } @@ -66,7 +106,7 @@ impl Aspa { #[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub enum Payload { - /// A route origin authorisation. + /// A route origin. RouteOrigin(RouteOrigin), /// A BGPsec router key. @@ -91,6 +131,10 @@ pub struct Timing { } impl Timing { + pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self { + Self { refresh, retry, expire } + } + pub fn refresh(self) -> Duration { Duration::from_secs(u64::from(self.refresh)) } @@ -104,3 +148,13 @@ impl Timing { } } + +impl Default for Timing { + fn default() -> Self { + Self { + refresh: 3600, + retry: 600, + expire: 7200, + } + } +} \ No newline at end of file diff --git a/src/rtr/pdu.rs b/src/rtr/pdu.rs index c83accd..7edc1e1 100644 --- a/src/rtr/pdu.rs +++ b/src/rtr/pdu.rs @@ -10,10 +10,11 @@ use anyhow::Result; use std::slice; use anyhow::bail; +use serde::Serialize; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -pub const HEADER_LEN: u32 = 8; +pub const HEADER_LEN: usize = 8; pub const MAX_PDU_LEN: u32 = 65535; pub const IPV4_PREFIX_LEN: u32 = 20; pub const IPV6_PREFIX_LEN: u32 = 32; @@ -78,6 +79,10 @@ macro_rules! concrete { self.header.session_id() } + pub fn pdu(&self) -> u8 { + self.header.pdu() + } + /// Returns the PDU size. /// /// The size is returned as a `u32` since that type is used in @@ -126,7 +131,7 @@ macro_rules! concrete { ) -> Result, io::Error> { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; - if res.header.pdu() == Error::PDU { + if res.header.pdu() == ErrorReport::PDU { // Since we should drop the session after an error, we // can safely ignore all the rest of the error for now. return Ok(Err(res.header)) @@ -183,7 +188,7 @@ macro_rules! concrete { // 所有PDU公共头部信息 #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Header { version: u8, pdu: u8, @@ -203,36 +208,34 @@ impl Header { } } - pub async fn read(sock: &mut TcpStream) -> Result { + pub async fn read(sock: &mut S) -> Result { let mut buf = [0u8; HEADER_LEN]; - - // 1. 精确读取 8 字节 sock.read_exact(&mut buf).await?; - // 2. 手动解析(大端) let version = buf[0]; let pdu = buf[1]; - let reserved = u16::from_be_bytes([buf[2], buf[3]]); - let length = u32::from_be_bytes([ - buf[4], buf[5], buf[6], buf[7], - ]); + let session_id = u16::from_be_bytes([buf[2], buf[3]]); + let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); - // 3. 基础合法性校验 - - if length < HEADER_LEN{ - bail!("Invalid PDU length"); + if length < HEADER_LEN as u32 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid PDU length", + )); } - // 限制最大长度 if length > MAX_PDU_LEN { - bail!("PDU too large"); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU too large", + )); } Ok(Self { version, pdu, - session_id: reserved, - length, + session_id: session_id.to_be(), + length: length.to_be(), }) } @@ -259,7 +262,7 @@ impl Header { common!(Header); #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct HeaderWithFlags { version: u8, pdu: u8, @@ -294,7 +297,7 @@ impl HeaderWithFlags { ]); // 3. 基础合法性校验 - if length < HEADER_LEN{ + if length < HEADER_LEN as u32{ bail!("Invalid PDU length"); } @@ -324,7 +327,7 @@ impl HeaderWithFlags { // Serial Notify #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct SerialNotify { header: Header, serial_number: u32, @@ -340,13 +343,17 @@ impl SerialNotify { } } + pub fn serial_number(self) -> u32 { + self.serial_number + } + } concrete!(SerialNotify); // Serial Query #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct SerialQuery { header: Header, serial_number: u32, @@ -372,7 +379,7 @@ concrete!(SerialQuery); // Reset Query #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ResetQuery { header: Header } @@ -382,7 +389,7 @@ impl ResetQuery { pub fn new(version: u8) -> Self { ResetQuery { - header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN), + header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32), } } } @@ -392,7 +399,7 @@ concrete!(ResetQuery); // Cache Response #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct CacheResponse { header: Header, } @@ -402,7 +409,7 @@ impl CacheResponse { pub fn new(version: u8, session_id: u16) -> Self { CacheResponse { - header: Header::new(version, Self::PDU, session_id, HEADER_LEN), + header: Header::new(version, Self::PDU, session_id, HEADER_LEN as u32), } } } @@ -411,7 +418,7 @@ concrete!(CacheResponse); // Flags -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Flags(u8); impl Flags { @@ -434,7 +441,7 @@ impl Flags { // IPv4 Prefix #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct IPv4Prefix { header: Header, @@ -479,7 +486,7 @@ concrete!(IPv4Prefix); // IPv6 Prefix #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct IPv6Prefix { header: Header, @@ -524,7 +531,7 @@ concrete!(IPv6Prefix); // End of Data -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)] pub enum EndOfData { V0(EndOfDataV0), V1(EndOfDataV1), @@ -544,7 +551,7 @@ impl EndOfData { } #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct EndOfDataV0 { header: Header, @@ -567,7 +574,7 @@ concrete!(EndOfDataV0); #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct EndOfDataV1 { header: Header, @@ -605,7 +612,7 @@ concrete!(EndOfDataV1); // Cache Reset #[repr(C, packed)] -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct CacheReset { header: Header, } @@ -615,7 +622,7 @@ impl CacheReset { pub fn new(version: u8) -> Self{ CacheReset { - header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN) + header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32) } } } @@ -624,7 +631,7 @@ concrete!(CacheReset); // Error Report -#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ErrorReport { octets: Vec, } @@ -703,7 +710,7 @@ impl ErrorReport { // TODO: 补全 // Router Key -#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct RouterKey { header: HeaderWithFlags, @@ -727,13 +734,13 @@ impl RouterKey { + 1 // flags // + self.ski.as_ref().len() + 4 // ASN - + self.subject_public_key_info.len() as u32; + + self.subject_public_key_info.len(); let header = HeaderWithFlags::new( self.header.version(), Self::PDU, self.flags, - length, + length as u32, ); w.write_all(&[ @@ -755,6 +762,7 @@ impl RouterKey { // ASPA +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Aspa{ header: HeaderWithFlags, @@ -773,13 +781,13 @@ impl Aspa { let length = HEADER_LEN + 1 + 4 - + (self.provider_asns.len() as u32 * 4); + + (self.provider_asns.len() * 4); let header = HeaderWithFlags::new( self.header.version(), Self::PDU, Flags::new(self.header.flags), - length, + length as u32, ); w.write_all(&[ @@ -836,7 +844,7 @@ mod tests { assert_eq!(decoded.version(), 1); assert_eq!(decoded.session_id(), 42); - assert_eq!(decoded.serial_number, 100u32.to_be()); + assert_eq!(decoded.serial_number(), 100u32.to_be()); } #[tokio::test] diff --git a/src/rtr/server/config.rs b/src/rtr/server/config.rs new file mode 100644 index 0000000..af46fe1 --- /dev/null +++ b/src/rtr/server/config.rs @@ -0,0 +1,14 @@ +#[derive(Debug, Clone)] +pub struct RtrServiceConfig { + pub max_connections: usize, + pub notify_queue_size: usize, +} + +impl Default for RtrServiceConfig { + fn default() -> Self { + Self { + max_connections: 1024, + notify_queue_size: 1024, + } + } +} \ No newline at end of file diff --git a/src/rtr/server/connection.rs b/src/rtr/server/connection.rs new file mode 100644 index 0000000..54870f4 --- /dev/null +++ b/src/rtr/server/connection.rs @@ -0,0 +1,79 @@ +use std::net::SocketAddr; +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use anyhow::{Context, Result}; +use tokio::net::TcpStream; +use tokio::sync::{broadcast, watch, OwnedSemaphorePermit}; +use tracing::error; + +use tokio_rustls::TlsAcceptor; + +use crate::rtr::cache::SharedRtrCache; +use crate::rtr::session::RtrSession; + +pub struct ConnectionGuard { + active_connections: Arc, + _permit: OwnedSemaphorePermit, +} + +impl ConnectionGuard { + pub fn new( + active_connections: Arc, + permit: OwnedSemaphorePermit, + ) -> Self { + active_connections.fetch_add(1, Ordering::Relaxed); + Self { + active_connections, + _permit: permit, + } + } +} + +impl Drop for ConnectionGuard { + fn drop(&mut self) { + self.active_connections.fetch_sub(1, Ordering::Relaxed); + } +} + +pub async fn handle_tcp_connection( + cache: SharedRtrCache, + stream: TcpStream, + peer_addr: SocketAddr, + notify_rx: broadcast::Receiver<()>, + shutdown_rx: watch::Receiver, +) -> Result<()> { + let session = RtrSession::new(cache, stream, notify_rx, shutdown_rx); + + if let Err(err) = session.run().await { + error!("RTR TCP session run failed for {}: {:?}", peer_addr, err); + return Err(err); + } + + Ok(()) +} + +pub async fn handle_tls_connection( + cache: SharedRtrCache, + stream: TcpStream, + peer_addr: SocketAddr, + acceptor: TlsAcceptor, + notify_rx: broadcast::Receiver<()>, + shutdown_rx: watch::Receiver, +) -> Result<()> { + let tls_stream = acceptor + .accept(stream) + .await + .with_context(|| format!("TLS handshake failed for {}", peer_addr))?; + + let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx); + + if let Err(err) = session.run().await { + error!("RTR TLS session run failed for {}: {:?}", peer_addr, err); + return Err(err); + } + + Ok(()) +} \ No newline at end of file diff --git a/src/rtr/server/listener.rs b/src/rtr/server/listener.rs new file mode 100644 index 0000000..77c30dc --- /dev/null +++ b/src/rtr/server/listener.rs @@ -0,0 +1,219 @@ +use std::net::SocketAddr; +use std::path::Path; +use std::sync::{ + Arc, + atomic::AtomicUsize, +}; + +use anyhow::{Context, Result}; +use tokio::net::TcpListener; +use tokio::sync::{broadcast, watch, Semaphore}; +use tracing::{info, warn}; + +use rustls::ServerConfig; +use tokio_rustls::TlsAcceptor; + +use crate::rtr::cache::SharedRtrCache; +use crate::rtr::server::connection::{ConnectionGuard, handle_tcp_connection, handle_tls_connection}; +use crate::rtr::server::tls::load_rustls_server_config; + +pub struct RtrServer { + bind_addr: SocketAddr, + cache: SharedRtrCache, + notify_tx: broadcast::Sender<()>, + shutdown_tx: watch::Sender, + connection_limiter: Arc, + active_connections: Arc, +} + +impl RtrServer { + pub fn new( + bind_addr: SocketAddr, + cache: SharedRtrCache, + notify_tx: broadcast::Sender<()>, + shutdown_tx: watch::Sender, + connection_limiter: Arc, + active_connections: Arc, + ) -> Self { + Self { + bind_addr, + cache, + notify_tx, + shutdown_tx, + connection_limiter, + active_connections, + } + } + + pub fn bind_addr(&self) -> SocketAddr { + self.bind_addr + } + + pub fn cache(&self) -> SharedRtrCache { + self.cache.clone() + } + + pub fn active_connections(&self) -> usize { + self.active_connections.load(std::sync::atomic::Ordering::Relaxed) + } + + pub async fn run_tcp(self) -> Result<()> { + let listener = TcpListener::bind(self.bind_addr) + .await + .with_context(|| format!("failed to bind TCP RTR server on {}", self.bind_addr))?; + + let mut shutdown_rx = self.shutdown_tx.subscribe(); + + info!("RTR TCP server listening on {}", self.bind_addr); + + loop { + tokio::select! { + changed = shutdown_rx.changed() => { + match changed { + Ok(()) => { + if *shutdown_rx.borrow() { + info!("RTR TCP listener {} shutting down", self.bind_addr); + return Ok(()); + } + } + Err(_) => { + info!("RTR TCP listener {} shutdown channel closed", self.bind_addr); + return Ok(()); + } + } + } + + accept_res = listener.accept() => { + let (stream, peer_addr) = match accept_res { + Ok(v) => v, + Err(err) => { + warn!("RTR TCP accept failed: {}", err); + continue; + } + }; + + if let Err(err) = stream.set_nodelay(true) { + warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err); + } + + let permit = match self.connection_limiter.clone().try_acquire_owned() { + Ok(permit) => permit, + Err(_) => { + warn!( + "RTR TCP connection rejected for {}: max connections reached ({})", + peer_addr, + self.connection_limiter.available_permits() + ); + drop(stream); + continue; + } + }; + + let cache = self.cache.clone(); + let notify_rx = self.notify_tx.subscribe(); + let shutdown_rx = self.shutdown_tx.subscribe(); + let active_connections = self.active_connections.clone(); + + info!("RTR TCP client connected: {}", peer_addr); + + tokio::spawn(async move { + let _guard = ConnectionGuard::new(active_connections, permit); + if let Err(err) = + handle_tcp_connection(cache, stream, peer_addr, notify_rx, shutdown_rx).await + { + warn!("RTR TCP session {} ended with error: {:?}", peer_addr, err); + } else { + info!("RTR TCP session {} closed", peer_addr); + } + }); + } + } + } + } + + pub async fn run_tls_from_pem( + self, + cert_path: impl AsRef, + key_path: impl AsRef, + ) -> Result<()> { + let tls_config = Arc::new(load_rustls_server_config(cert_path, key_path)?); + self.run_tls(tls_config).await + } + + pub async fn run_tls(self, tls_config: Arc) -> Result<()> { + let listener = TcpListener::bind(self.bind_addr) + .await + .with_context(|| format!("failed to bind TLS RTR server on {}", self.bind_addr))?; + + let acceptor = TlsAcceptor::from(tls_config); + let mut shutdown_rx = self.shutdown_tx.subscribe(); + + info!("RTR TLS server listening on {}", self.bind_addr); + + loop { + tokio::select! { + changed = shutdown_rx.changed() => { + match changed { + Ok(()) => { + if *shutdown_rx.borrow() { + info!("RTR TLS listener {} shutting down", self.bind_addr); + return Ok(()); + } + } + Err(_) => { + info!("RTR TLS listener {} shutdown channel closed", self.bind_addr); + return Ok(()); + } + } + } + + accept_res = listener.accept() => { + let (stream, peer_addr) = match accept_res { + Ok(v) => v, + Err(err) => { + warn!("RTR TLS accept failed: {}", err); + continue; + } + }; + + if let Err(err) = stream.set_nodelay(true) { + warn!("failed to enable TCP_NODELAY for {}: {}", peer_addr, err); + } + + let permit = match self.connection_limiter.clone().try_acquire_owned() { + Ok(permit) => permit, + Err(_) => { + warn!("RTR TLS connection rejected for {}: max connections reached", peer_addr); + drop(stream); + continue; + } + }; + + let cache = self.cache.clone(); + let acceptor = acceptor.clone(); + let notify_rx = self.notify_tx.subscribe(); + let shutdown_rx = self.shutdown_tx.subscribe(); + let active_connections = self.active_connections.clone(); + + info!("RTR TLS client connected: {}", peer_addr); + + tokio::spawn(async move { + let _guard = ConnectionGuard::new(active_connections, permit); + if let Err(err) = handle_tls_connection( + cache, + stream, + peer_addr, + acceptor, + notify_rx, + shutdown_rx, + ).await { + warn!("RTR TLS session {} ended with error: {:?}", peer_addr, err); + } else { + info!("RTR TLS session {} closed", peer_addr); + } + }); + } + } + } + } +} \ No newline at end of file diff --git a/src/rtr/server/mod.rs b/src/rtr/server/mod.rs new file mode 100644 index 0000000..3d833af --- /dev/null +++ b/src/rtr/server/mod.rs @@ -0,0 +1,12 @@ +pub mod config; +pub mod connection; +pub mod listener; +pub mod notifier; +pub mod service; +pub mod tls; + +pub use config::RtrServiceConfig; +pub use listener::RtrServer; +pub use notifier::RtrNotifier; +pub use service::{RtrService, RunningRtrService}; +pub use tls::load_rustls_server_config; \ No newline at end of file diff --git a/src/rtr/server/notifier.rs b/src/rtr/server/notifier.rs new file mode 100644 index 0000000..208b18e --- /dev/null +++ b/src/rtr/server/notifier.rs @@ -0,0 +1,16 @@ +use tokio::sync::broadcast; + +#[derive(Clone)] +pub struct RtrNotifier { + tx: broadcast::Sender<()>, +} + +impl RtrNotifier { + pub fn new(tx: broadcast::Sender<()>) -> Self { + Self { tx } + } + + pub fn notify_cache_updated(&self) { + let _ = self.tx.send(()); + } +} \ No newline at end of file diff --git a/src/rtr/server/service.rs b/src/rtr/server/service.rs new file mode 100644 index 0000000..275a94a --- /dev/null +++ b/src/rtr/server/service.rs @@ -0,0 +1,154 @@ +use std::net::SocketAddr; +use std::path::Path; +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use tokio::sync::{broadcast, watch, Semaphore}; +use tokio::task::JoinHandle; +use tracing::error; + +use crate::rtr::cache::SharedRtrCache; +use crate::rtr::server::config::RtrServiceConfig; +use crate::rtr::server::listener::RtrServer; +use crate::rtr::server::notifier::RtrNotifier; + +pub struct RtrService { + cache: SharedRtrCache, + notify_tx: broadcast::Sender<()>, + shutdown_tx: watch::Sender, + connection_limiter: Arc, + active_connections: Arc, + config: RtrServiceConfig, +} + +impl RtrService { + pub fn new(cache: SharedRtrCache) -> Self { + Self::with_config(cache, RtrServiceConfig::default()) + } + + pub fn with_config(cache: SharedRtrCache, config: RtrServiceConfig) -> Self { + let (notify_tx, _) = broadcast::channel(config.notify_queue_size); + let (shutdown_tx, _) = watch::channel(false); + + Self { + cache, + notify_tx, + shutdown_tx, + connection_limiter: Arc::new(Semaphore::new(config.max_connections)), + active_connections: Arc::new(AtomicUsize::new(0)), + config, + } + } + + pub fn cache(&self) -> SharedRtrCache { + self.cache.clone() + } + + pub fn notifier(&self) -> RtrNotifier { + RtrNotifier::new(self.notify_tx.clone()) + } + + pub fn notify_cache_updated(&self) { + let _ = self.notify_tx.send(()); + } + + pub fn active_connections(&self) -> usize { + self.active_connections.load(Ordering::Relaxed) + } + + pub fn max_connections(&self) -> usize { + self.config.max_connections + } + + pub fn tcp_server(&self, bind_addr: SocketAddr) -> RtrServer { + RtrServer::new( + bind_addr, + self.cache.clone(), + self.notify_tx.clone(), + self.shutdown_tx.clone(), + self.connection_limiter.clone(), + self.active_connections.clone(), + ) + } + + pub fn tls_server(&self, bind_addr: SocketAddr) -> RtrServer { + RtrServer::new( + bind_addr, + self.cache.clone(), + self.notify_tx.clone(), + self.shutdown_tx.clone(), + self.connection_limiter.clone(), + self.active_connections.clone(), + ) + } + + pub fn spawn_tcp(&self, bind_addr: SocketAddr) -> JoinHandle<()> { + let server = self.tcp_server(bind_addr); + tokio::spawn(async move { + if let Err(err) = server.run_tcp().await { + error!("RTR TCP server {} exited with error: {:?}", bind_addr, err); + } + }) + } + + pub fn spawn_tls_from_pem( + &self, + bind_addr: SocketAddr, + cert_path: impl AsRef, + key_path: impl AsRef, + ) -> JoinHandle<()> { + let cert_path = cert_path.as_ref().to_path_buf(); + let key_path = key_path.as_ref().to_path_buf(); + let server = self.tls_server(bind_addr); + + tokio::spawn(async move { + if let Err(err) = server.run_tls_from_pem(cert_path, key_path).await { + error!("RTR TLS server {} exited with error: {:?}", bind_addr, err); + } + }) + } + + pub fn spawn_tcp_and_tls_from_pem( + &self, + tcp_bind_addr: SocketAddr, + tls_bind_addr: SocketAddr, + cert_path: impl AsRef, + key_path: impl AsRef, + ) -> RunningRtrService { + let tcp_handle = self.spawn_tcp(tcp_bind_addr); + let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path); + + RunningRtrService { + shutdown_tx: self.shutdown_tx.clone(), + handles: vec![tcp_handle, tls_handle], + } + } + + pub fn spawn_tcp_only(&self, tcp_bind_addr: SocketAddr) -> RunningRtrService { + let tcp_handle = self.spawn_tcp(tcp_bind_addr); + + RunningRtrService { + shutdown_tx: self.shutdown_tx.clone(), + handles: vec![tcp_handle], + } + } +} + +pub struct RunningRtrService { + shutdown_tx: watch::Sender, + handles: Vec>, +} + +impl RunningRtrService { + pub fn shutdown(&self) { + let _ = self.shutdown_tx.send(true); + } + + pub async fn wait(self) { + for handle in self.handles { + let _ = handle.await; + } + } +} \ No newline at end of file diff --git a/src/rtr/server/tls.rs b/src/rtr/server/tls.rs new file mode 100644 index 0000000..6e178bc --- /dev/null +++ b/src/rtr/server/tls.rs @@ -0,0 +1,52 @@ +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; + +use anyhow::{anyhow, Context, Result}; +use rustls::ServerConfig; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + +pub fn load_rustls_server_config( + cert_path: impl AsRef, + key_path: impl AsRef, +) -> Result { + let cert_path: PathBuf = cert_path.as_ref().to_path_buf(); + let key_path: PathBuf = key_path.as_ref().to_path_buf(); + + let certs = load_certs(&cert_path) + .with_context(|| format!("failed to load certs from {}", cert_path.display()))?; + + let key = load_private_key(&key_path) + .with_context(|| format!("failed to load private key from {}", key_path.display()))?; + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| anyhow!("invalid certificate/key pair: {}", e))?; + + Ok(config) +} + +fn load_certs(path: &Path) -> Result>> { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + + let certs = rustls_pemfile::certs(&mut reader) + .collect::, _>>()?; + + if certs.is_empty() { + return Err(anyhow!("no certificates found in {}", path.display())); + } + + Ok(certs) +} + +fn load_private_key(path: &Path) -> Result> { + let file = File::open(path)?; + let mut reader = BufReader::new(file); + + let key = rustls_pemfile::private_key(&mut reader)? + .ok_or_else(|| anyhow!("no private key found in {}", path.display()))?; + + Ok(key) +} \ No newline at end of file diff --git a/src/rtr/session.rs b/src/rtr/session.rs index 05892a3..500b42b 100644 --- a/src/rtr/session.rs +++ b/src/rtr/session.rs @@ -1,16 +1,16 @@ -use std::sync::Arc; - use anyhow::{bail, Result}; use tokio::io; -use tokio::net::TcpStream; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::{broadcast, watch}; use tracing::warn; -use crate::rtr::cache::{Delta, RtrCache, SerialResult}; +use crate::data_model::resources::ip_resources::IPAddress; +use crate::rtr::cache::{Delta, SerialResult, SharedRtrCache}; use crate::rtr::error_type::ErrorCode; -use crate::rtr::payload::{Payload, RouteOrigin, Timing}; +use crate::rtr::payload::{Payload, RouteOrigin}; use crate::rtr::pdu::{ CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix, - ResetQuery, SerialQuery, + ResetQuery, SerialNotify, SerialQuery, }; const SUPPORTED_MAX_VERSION: u8 = 2; @@ -26,59 +26,116 @@ enum SessionState { Closed, } -pub struct RtrSession { - cache: Arc, +pub struct RtrSession { + cache: SharedRtrCache, version: Option, - stream: TcpStream, + stream: S, state: SessionState, + notify_rx: broadcast::Receiver<()>, + shutdown_rx: watch::Receiver, } -impl RtrSession { - pub fn new(cache: Arc, stream: TcpStream) -> Self { +impl RtrSession +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new( + cache: SharedRtrCache, + stream: S, + notify_rx: broadcast::Receiver<()>, + shutdown_rx: watch::Receiver, + ) -> Self { Self { cache, version: None, stream, state: SessionState::Connected, + notify_rx, + shutdown_rx, } } pub async fn run(mut self) -> Result<()> { loop { - let header = match Header::read(&mut self.stream).await { - Ok(h) => h, - Err(_) => return Ok(()), - }; - - if self.version.is_none() { - self.negotiate_version(header.version()).await?; - } else if header.version() != self.version.unwrap() { - self.send_unsupported_version(self.version.unwrap()).await?; - bail!("version changed within session"); + tokio::select! { + changed = self.shutdown_rx.changed() => { + match changed { + Ok(()) => { + if *self.shutdown_rx.borrow() { + self.state = SessionState::Closed; + return Ok(()); + } + } + Err(_) => { + // shutdown sender dropped,按关闭处理 + self.state = SessionState::Closed; + return Ok(()); + } + } } - match header.pdu() { - ResetQuery::PDU => { - let _ = ResetQuery::read_payload(header, &mut self.stream).await?; - self.handle_reset_query().await?; - } - SerialQuery::PDU => { - let query = SerialQuery::read_payload(header, &mut self.stream).await?; - let session_id = query.session_id(); - let serial = u32::from_be(query.serial_number()); - self.handle_serial(session_id, serial).await?; - } - ErrorReport::PDU => { - let _ = ErrorReport::skip_payload(header, &mut self.stream).await; + header_res = Header::read(&mut self.stream) => { + let header = match header_res { + Ok(h) => h, + Err(_) => { + self.state = SessionState::Closed; + return Ok(()); + } + }; + + if self.version.is_none() { + self.negotiate_version(header.version()).await?; + } else if header.version() != self.version.unwrap() { + self.send_unsupported_version(self.version.unwrap()).await?; self.state = SessionState::Closed; - return Ok(()); + bail!("version changed within session"); } - _ => { - self.send_error(header.version(), ErrorCode::UnsupportedPduType, Some(&header), &[]) + + match header.pdu() { + ResetQuery::PDU => { + let _ = ResetQuery::read_payload(header, &mut self.stream).await?; + self.handle_reset_query().await?; + } + SerialQuery::PDU => { + let query = SerialQuery::read_payload(header, &mut self.stream).await?; + let session_id = query.session_id(); + let serial = u32::from_be(query.serial_number()); + self.handle_serial(session_id, serial).await?; + } + ErrorReport::PDU => { + let _ = ErrorReport::skip_payload(header, &mut self.stream).await; + self.state = SessionState::Closed; + return Ok(()); + } + _ => { + self.send_error( + header.version(), + ErrorCode::UnsupportedPduType, + Some(&header), + &[], + ) .await?; - return Ok(()); + self.state = SessionState::Closed; + return Ok(()); + } } } + + notify_res = self.notify_rx.recv(), + if self.state == SessionState::Established && self.version.is_some() => { + match notify_res { + Ok(()) => { + self.handle_notify().await?; + } + Err(broadcast::error::RecvError::Lagged(_)) => { + self.handle_notify().await?; + } + Err(broadcast::error::RecvError::Closed) => { + // notify 通道关闭,不影响已有会话,继续跑,真正关闭由 shutdown_rx 控制 + } + } + } + } } } @@ -110,38 +167,70 @@ impl RtrSession { &[], ErrorCode::UnsupportedProtocolVersion.description(), ) - .write(&mut self.stream) - .await + .write(&mut self.stream) + .await } async fn handle_reset_query(&mut self) -> Result<()> { + let (payloads, session_id, serial) = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + let snapshot = cache.snapshot(); + let payloads = snapshot.payloads_for_rtr(); + let session_id = cache.session_id(); + let serial = cache.serial(); + (payloads, session_id, serial) + }; + + self.write_cache_response(session_id).await?; + self.send_payloads(&payloads, true).await?; + self.write_end_of_data(session_id, serial).await?; + self.state = SessionState::Established; - - let snapshot = self.cache.snapshot(); - self.write_cache_response().await?; - self.send_payloads(snapshot.payloads(), true).await?; - self.write_end_of_data(self.cache.session_id(), self.cache.serial()) - .await?; - Ok(()) } async fn handle_serial(&mut self, client_session: u16, client_serial: u32) -> Result<()> { - let current_session = self.cache.session_id(); - let current_serial = self.cache.serial(); + let serial_result = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + cache.get_deltas_since(client_session, client_serial) + }; - match self.cache.get_deltas_since(client_session, client_serial) { + match serial_result { SerialResult::ResetRequired => { self.write_cache_reset().await?; + self.state = SessionState::Established; return Ok(()); } SerialResult::UpToDate => { + let (current_session, current_serial) = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + (cache.session_id(), cache.serial()) + }; + self.write_end_of_data(current_session, current_serial) .await?; + self.state = SessionState::Established; return Ok(()); } SerialResult::Deltas(deltas) => { - self.write_cache_response().await?; + let (current_session, current_serial) = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + (cache.session_id(), cache.serial()) + }; + + self.write_cache_response(current_session).await?; for delta in deltas { self.send_delta(&delta).await?; } @@ -154,32 +243,55 @@ impl RtrSession { Ok(()) } - async fn write_cache_response(&mut self) -> Result<()> { - let version = self.version.ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") - })?; + async fn handle_notify(&mut self) -> Result<()> { + let (session_id, serial) = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + (cache.session_id(), cache.serial()) + }; - CacheResponse::new(version, self.cache.session_id()) + self.send_serial_notify(session_id, serial).await + } + + fn version(&self) -> io::Result { + self.version + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "version not negotiated")) + } + + async fn send_serial_notify(&mut self, session_id: u16, serial: u32) -> Result<()> { + let version = self.version()?; + SerialNotify::new(version, session_id, serial) + .write(&mut self.stream) + .await?; + Ok(()) + } + + async fn write_cache_response(&mut self, session_id: u16) -> Result<()> { + let version = self.version()?; + CacheResponse::new(version, session_id) .write(&mut self.stream) .await?; Ok(()) } async fn write_cache_reset(&mut self) -> Result<()> { - let version = self.version.ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") - })?; - + let version = self.version()?; CacheReset::new(version).write(&mut self.stream).await?; Ok(()) } async fn write_end_of_data(&mut self, session_id: u16, serial: u32) -> Result<()> { - let version = self.version.ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") - })?; + let version = self.version()?; + let timing = { + let cache = self + .cache + .read() + .map_err(|_| anyhow::anyhow!("cache read lock poisoned"))?; + cache.timing() + }; - let timing = self.timing(); let end = EndOfData::new(version, session_id, serial, timing); match end { EndOfData::V0(pdu) => pdu.write(&mut self.stream).await?, @@ -189,20 +301,20 @@ impl RtrSession { Ok(()) } - async fn send_payloads(&mut self, payloads: Vec, announce: bool) -> Result<()> { + async fn send_payloads(&mut self, payloads: &[Payload], announce: bool) -> Result<()> { for payload in payloads { - self.send_payload(&payload, announce).await?; + self.send_payload(payload, announce).await?; } Ok(()) } - async fn send_delta(&mut self, delta: &Arc) -> Result<()> { - for payload in delta.withdrawn() { - self.send_payload(payload, false).await?; - } + async fn send_delta(&mut self, delta: &Delta) -> Result<()> { for payload in delta.announced() { self.send_payload(payload, true).await?; } + for payload in delta.withdrawn() { + self.send_payload(payload, false).await?; + } Ok(()) } @@ -222,9 +334,7 @@ impl RtrSession { } async fn send_route_origin(&mut self, origin: &RouteOrigin, announce: bool) -> Result<()> { - let version = self.version.ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") - })?; + let version = self.version()?; let flags = Flags::new(if announce { ANNOUNCE_FLAG @@ -236,15 +346,17 @@ impl RtrSession { let prefix_len = prefix.prefix_length; let max_len = origin.max_length(); - if let Some(v4) = prefix.address.to_ipv4() { - IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn()) - .write(&mut self.stream) - .await?; - } else { - let v6 = prefix.address.to_ipv6(); - IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn()) - .write(&mut self.stream) - .await?; + match prefix.address { + IPAddress::V4(v4) => { + IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn()) + .write(&mut self.stream) + .await?; + } + IPAddress::V6(v6) => { + IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn()) + .write(&mut self.stream) + .await?; + } } Ok(()) @@ -257,24 +369,10 @@ impl RtrSession { offending_header: Option<&Header>, text: &[u8], ) -> io::Result<()> { - let offending = offending_header - .map(|h| h.as_ref()) - .unwrap_or(&[]); + let offending = offending_header.map(|h| h.as_ref()).unwrap_or(&[]); ErrorReport::new(version, code.as_u16(), offending, text) .write(&mut self.stream) .await } - - fn timing(&self) -> Timing { - let refresh = self.cache.refresh_interval().as_secs(); - let retry = self.cache.retry_interval().as_secs(); - let expire = self.cache.expire_interval().as_secs(); - - Timing { - refresh: refresh.min(u32::MAX as u64) as u32, - retry: retry.min(u32::MAX as u64) as u32, - expire: expire.min(u32::MAX as u64) as u32, - } - } -} +} \ No newline at end of file diff --git a/src/rtr/store_db.rs b/src/rtr/store_db.rs index e8f6400..e824032 100644 --- a/src/rtr/store_db.rs +++ b/src/rtr/store_db.rs @@ -274,20 +274,30 @@ impl RtrStore { } pub fn load_deltas_since(&self, serial: u32) -> Result> { - let cf_handle = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?; - let mut out = Vec::new(); - let start_key = delta_key(serial.wrapping_add(1)); - - let iter = self + let cf_handle = self .db - .iterator_cf(cf_handle, IteratorMode::From(&start_key, Direction::Forward)); + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; + + let start_key = delta_key(serial.wrapping_add(1)); + let iter = self.db.iterator_cf( + cf_handle, + IteratorMode::From(&start_key, Direction::Forward), + ); + + let mut out = Vec::new(); + + for item in iter { + let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; + + let parsed = delta_key_serial(key.as_ref()) + .ok_or_else(|| anyhow!("Invalid delta key"))?; - for (key, value) in iter { - let parsed = delta_key_serial(&key).ok_or_else(|| anyhow!("Invalid delta key"))?; if parsed <= serial { continue; } - let delta: Delta = serde_json::from_slice(&value)?; + + let delta: Delta = serde_json::from_slice(value.as_ref())?; out.push(delta); } diff --git a/tests/common/mod.rs b/tests/common/mod.rs new file mode 100644 index 0000000..a7e5207 --- /dev/null +++ b/tests/common/mod.rs @@ -0,0 +1 @@ +pub mod test_helper; \ No newline at end of file diff --git a/tests/common/test_helper.rs b/tests/common/test_helper.rs new file mode 100644 index 0000000..9ac9cc3 --- /dev/null +++ b/tests/common/test_helper.rs @@ -0,0 +1,332 @@ +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::fmt::Write; + +use serde_json::{json, Value}; + +use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use rpki::rtr::payload::{Payload, RouteOrigin}; +use rpki::rtr::pdu::{CacheResponse, EndOfDataV1, IPv4Prefix, IPv6Prefix}; +use rpki::rtr::cache::SerialResult; + +pub struct RtrDebugDumper { + entries: Vec, +} + +impl RtrDebugDumper { + pub fn new() -> Self { + Self { entries: Vec::new() } + } + + pub fn push(&mut self, pdu: u8, body: &T) { + self.entries.push(json!({ + "pdu": pdu, + "pdu_name": pdu_type_name(pdu), + "body": body + })); + } + + pub fn push_value(&mut self, pdu: u8, body: Value) { + self.entries.push(json!({ + "pdu": pdu, + "pdu_name": pdu_type_name(pdu), + "body": body + })); + } + + pub fn print_pretty(&self, test_name: &str) { + println!( + "\n===== RTR Debug Dump: {} =====\n{}\n", + test_name, + serde_json::to_string_pretty(&self.entries).unwrap() + ); + } +} + +pub fn pdu_type_name(pdu: u8) -> &'static str { + match pdu { + 0 => "Serial Notify", + 1 => "Serial Query", + 2 => "Reset Query", + 3 => "Cache Response", + 4 => "IPv4 Prefix", + 6 => "IPv6 Prefix", + 7 => "End of Data", + 8 => "Cache Reset", + 9 => "Router Key", + 10 => "Error Report", + 11 => "ASPA", + 255 => "Reserved", + _ => "Unknown", + } +} + +pub fn dump_cache_response(resp: &CacheResponse) -> Value { + json!({ + "header": { + "version": resp.version(), + "pdu": resp.pdu(), + "session_id": resp.session_id(), + "length": 8 + } + }) +} + +pub fn dump_ipv4_prefix(p: &IPv4Prefix) -> Value { + json!({ + "header": { + "version": p.version(), + "pdu": p.pdu(), + "session_id": 0, + "length": 20 + }, + "flags": { + "raw": if p.flag().is_announce() { 1 } else { 0 }, + "announce": p.flag().is_announce() + }, + "prefix": p.prefix().to_string(), + "prefix_len": p.prefix_len(), + "max_len": p.max_len(), + "asn": p.asn().into_u32() + }) +} + +pub fn dump_ipv6_prefix(p: &IPv6Prefix) -> Value { + json!({ + "header": { + "version": p.version(), + "pdu": p.pdu(), + "session_id": 0, + "length": 32 + }, + "flags": { + "raw": if p.flag().is_announce() { 1 } else { 0 }, + "announce": p.flag().is_announce() + }, + "prefix": p.prefix().to_string(), + "prefix_len": p.prefix_len(), + "max_len": p.max_len(), + "asn": p.asn().into_u32() + }) +} + +pub fn dump_eod_v1(eod: &EndOfDataV1) -> Value { + let timing = eod.timing(); + json!({ + "header": { + "version": eod.version(), + "pdu": eod.pdu(), + "session_id": eod.session_id(), + "length": 24 + }, + "serial_number": eod.serial_number(), + "refresh_interval": timing.refresh, + "retry_interval": timing.retry, + "expire_interval": timing.expire + }) +} + +pub fn dump_cache_reset(version: u8, pdu: u8) -> Value { + json!({ + "header": { + "version": version, + "pdu": pdu, + "session_id": 0, + "length": 8 + } + }) +} + +pub fn v4_prefix(a: u8, b: u8, c: u8, d: u8, prefix_len: u8) -> IPAddressPrefix { + IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(a, b, c, d)), + prefix_length: prefix_len, + } +} + +pub fn v6_prefix(addr: Ipv6Addr, prefix_len: u8) -> IPAddressPrefix { + IPAddressPrefix { + address: IPAddress::from_ipv6(addr), + prefix_length: prefix_len, + } +} + +pub fn v4_origin( + a: u8, + b: u8, + c: u8, + d: u8, + prefix_len: u8, + max_len: u8, + asn: u32, +) -> RouteOrigin { + let prefix = v4_prefix(a, b, c, d, prefix_len); + RouteOrigin::new(prefix, max_len, asn.into()) +} + +pub fn v6_origin( + addr: Ipv6Addr, + prefix_len: u8, + max_len: u8, + asn: u32, +) -> RouteOrigin { + let prefix = v6_prefix(addr, prefix_len); + RouteOrigin::new(prefix, max_len, asn.into()) +} + +pub fn as_route_origin(payload: &Payload) -> &RouteOrigin { + match payload { + Payload::RouteOrigin(ro) => ro, + _ => panic!("expected RouteOrigin payload"), + } +} + +pub fn as_v4_route_origin(payload: &Payload) -> &RouteOrigin { + let ro = as_route_origin(payload); + assert!(ro.prefix().address.is_ipv4(), "expected IPv4 RouteOrigin"); + ro +} + +pub fn as_v6_route_origin(payload: &Payload) -> &RouteOrigin { + let ro = as_route_origin(payload); + assert!(ro.prefix().address.is_ipv6(), "expected IPv6 RouteOrigin"); + ro +} + +pub fn route_origin_to_string(ro: &RouteOrigin) -> String { + let prefix = ro.prefix(); + let addr = match prefix.address { + IPAddress::V4(v4) => v4.to_string(), + IPAddress::V6(v6) => v6.to_string(), + }; + + format!( + "{}/{}-{} AS{}", + addr, + prefix.prefix_length, + ro.max_length(), + ro.asn().into_u32() + ) +} + +pub fn payload_to_string(payload: &Payload) -> String { + match payload { + Payload::RouteOrigin(ro) => format!("RouteOrigin({})", route_origin_to_string(ro)), + Payload::RouterKey(_) => "RouterKey(...)".to_string(), + Payload::Aspa(_) => "Aspa(...)".to_string(), + } +} + +pub fn payloads_to_pretty_lines(payloads: &[Payload]) -> String { + let mut out = String::new(); + for (idx, payload) in payloads.iter().enumerate() { + let _ = writeln!(&mut out, " [{}] {}", idx, payload_to_string(payload)); + } + out +} + +pub fn print_payloads(label: &str, payloads: &[Payload]) { + println!( + "\n===== {} =====\n{}", + label, + payloads_to_pretty_lines(payloads) + ); +} + +pub fn serial_result_to_string(result: &SerialResult) -> String { + match result { + SerialResult::UpToDate => "UpToDate".to_string(), + SerialResult::ResetRequired => "ResetRequired".to_string(), + SerialResult::Deltas(deltas) => { + let serials: Vec = deltas.iter().map(|d| d.serial()).collect(); + format!("Deltas {:?}", serials) + } + } +} + +pub fn print_serial_result(label: &str, result: &SerialResult) { + println!("\n===== {} =====\n{}\n", label, serial_result_to_string(result)); +} + +pub fn bytes_to_hex(bytes: &[u8]) -> String { + let mut out = String::with_capacity(bytes.len() * 2); + for b in bytes { + let _ = write!(&mut out, "{:02x}", b); + } + out +} + +pub fn print_snapshot_hashes(label: &str, snapshot: &rpki::rtr::cache::Snapshot) { + println!( + "\n===== {} =====\norigins_hash={}\nrouter_keys_hash={}\naspas_hash={}\nsnapshot_hash={}\n", + label, + bytes_to_hex(&snapshot.origins_hash()), + bytes_to_hex(&snapshot.router_keys_hash()), + bytes_to_hex(&snapshot.aspas_hash()), + bytes_to_hex(&snapshot.snapshot_hash()), + ); +} + +pub fn test_report( + name: &str, + purpose: &str, + input: &str, + output: &str, +) { + println!( + "\n==================== TEST REPORT ====================\n测试名称: {}\n测试目的: {}\n\n【输入】\n{}\n【输出】\n{}\n====================================================\n", + name, purpose, input, output + ); +} + +pub fn payloads_to_string(payloads: &[Payload]) -> String { + let mut out = String::new(); + for (idx, payload) in payloads.iter().enumerate() { + let _ = writeln!(&mut out, " [{}] {}", idx, payload_to_string(payload)); + } + if out.is_empty() { + out.push_str(" \n"); + } + out +} + +pub fn snapshot_hashes_to_string(snapshot: &rpki::rtr::cache::Snapshot) -> String { + format!( + " origins_hash: {}\n router_keys_hash: {}\n aspas_hash: {}\n snapshot_hash: {}\n", + bytes_to_hex(&snapshot.origins_hash()), + bytes_to_hex(&snapshot.router_keys_hash()), + bytes_to_hex(&snapshot.aspas_hash()), + bytes_to_hex(&snapshot.snapshot_hash()), + ) +} + +pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -> String { + match result { + rpki::rtr::cache::SerialResult::UpToDate => { + " result: UpToDate\n".to_string() + } + rpki::rtr::cache::SerialResult::ResetRequired => { + " result: ResetRequired\n".to_string() + } + rpki::rtr::cache::SerialResult::Deltas(deltas) => { + let mut out = String::new(); + let _ = writeln!(&mut out, " result: Deltas"); + for (idx, delta) in deltas.iter().enumerate() { + let _ = writeln!(&mut out, " delta[{}].serial: {}", idx, delta.serial()); + let _ = writeln!(&mut out, " delta[{}].announced:", idx); + out.push_str(&indent_block(&payloads_to_string(delta.announced()), 4)); + let _ = writeln!(&mut out, " delta[{}].withdrawn:", idx); + out.push_str(&indent_block(&payloads_to_string(delta.withdrawn()), 4)); + } + out + } + } +} + +pub fn indent_block(text: &str, spaces: usize) -> String { + let pad = " ".repeat(spaces); + let mut out = String::new(); + for line in text.lines() { + let _ = writeln!(&mut out, "{}{}", pad, line); + } + out +} \ No newline at end of file diff --git a/tests/test_cache.rs b/tests/test_cache.rs new file mode 100644 index 0000000..013c9f5 --- /dev/null +++ b/tests/test_cache.rs @@ -0,0 +1,743 @@ +mod common; + +use std::collections::VecDeque; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; + +use common::test_helper::{ + as_route_origin, as_v4_route_origin, indent_block, payloads_to_string, + serial_result_detail_to_string, snapshot_hashes_to_string, test_report, v4_origin, v6_origin, +}; + +use rpki::rtr::cache::{Delta, RtrCacheBuilder, SerialResult, Snapshot}; +use rpki::rtr::payload::{Payload, Timing}; +use rpki::rtr::store_db::RtrStore; + +fn delta_to_string(delta: &Delta) -> String { + format!( + "serial: {}\nannounced:\n{}withdrawn:\n{}", + delta.serial(), + indent_block(&payloads_to_string(delta.announced()), 2), + indent_block(&payloads_to_string(delta.withdrawn()), 2), + ) +} + +fn deltas_window_to_string(deltas: &VecDeque>) -> String { + if deltas.is_empty() { + return " \n".to_string(); + } + + let mut out = String::new(); + for (idx, delta) in deltas.iter().enumerate() { + out.push_str(&format!("delta[{}]:\n", idx)); + out.push_str(&indent_block(&delta_to_string(delta), 2)); + } + out +} + +fn get_deltas_since_input_to_string( + cache_session_id: u16, + cache_serial: u32, + client_session: u16, + client_serial: u32, +) -> String { + format!( + "cache.session_id: {}\ncache.serial: {}\nclient_session: {}\nclient_serial: {}\n", + cache_session_id, cache_serial, client_session, client_serial + ) +} + +fn snapshot_hashes_and_sorted_view_to_string(snapshot: &Snapshot) -> String { + let payloads = snapshot.payloads_for_rtr(); + format!( + "hashes:\n{}sorted payloads_for_rtr:\n{}", + indent_block(&snapshot_hashes_to_string(snapshot), 2), + indent_block(&payloads_to_string(&payloads), 2), + ) +} + +#[test] +fn snapshot_hash_is_stable_for_same_content_with_different_input_order() { + let a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let c = v6_origin( + Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + 32, + 48, + 64498, + ); + + let s1_input = vec![ + Payload::RouteOrigin(a.clone()), + Payload::RouteOrigin(b.clone()), + Payload::RouteOrigin(c.clone()), + ]; + let s2_input = vec![ + Payload::RouteOrigin(c), + Payload::RouteOrigin(a), + Payload::RouteOrigin(b), + ]; + + let s1 = Snapshot::from_payloads(s1_input.clone()); + let s2 = Snapshot::from_payloads(s2_input.clone()); + + let input = format!( + "s1 原始输入 payloads:\n{}\ns2 原始输入 payloads:\n{}", + indent_block(&payloads_to_string(&s1_input), 2), + indent_block(&payloads_to_string(&s2_input), 2), + ); + + let output = format!( + "s1:\n{}\ns2:\n{}\n结论:\n same_content: {}\n same_origins: {}\n snapshot_hash 相同: {}\n origins_hash 相同: {}\n", + indent_block(&snapshot_hashes_and_sorted_view_to_string(&s1), 2), + indent_block(&snapshot_hashes_and_sorted_view_to_string(&s2), 2), + s1.same_content(&s2), + s1.same_origins(&s2), + s1.snapshot_hash() == s2.snapshot_hash(), + s1.origins_hash() == s2.origins_hash(), + ); + + test_report( + "snapshot_hash_is_stable_for_same_content_with_different_input_order", + "验证相同语义内容即使原始输入顺序不同,Snapshot 的 hash 仍然稳定一致。", + &input, + &output, + ); + + assert!(s1.same_content(&s2)); + assert!(s1.same_origins(&s2)); + assert_eq!(s1.snapshot_hash(), s2.snapshot_hash()); + assert_eq!(s1.origins_hash(), s2.origins_hash()); +} + +#[test] +fn snapshot_diff_reports_announced_and_withdrawn_correctly() { + let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let new_c = v6_origin( + Ipv6Addr::new(0x2001, 0xdb8, 0, 1, 0, 0, 0, 0), + 48, + 48, + 64499, + ); + + let old_input = vec![ + Payload::RouteOrigin(old_a.clone()), + Payload::RouteOrigin(old_b.clone()), + ]; + let new_input = vec![ + Payload::RouteOrigin(old_b), + Payload::RouteOrigin(new_c.clone()), + ]; + + let old_snapshot = Snapshot::from_payloads(old_input.clone()); + let new_snapshot = Snapshot::from_payloads(new_input.clone()); + + let (announced, withdrawn) = old_snapshot.diff(&new_snapshot); + + let input = format!( + "old_snapshot 原始输入:\n{}\nnew_snapshot 原始输入:\n{}", + indent_block(&payloads_to_string(&old_input), 2), + indent_block(&payloads_to_string(&new_input), 2), + ); + + let output = format!( + "announced:\n{}withdrawn:\n{}", + indent_block(&payloads_to_string(&announced), 2), + indent_block(&payloads_to_string(&withdrawn), 2), + ); + + test_report( + "snapshot_diff_reports_announced_and_withdrawn_correctly", + "验证 diff() 能正确找出 announced 和 withdrawn 的 payload。", + &input, + &output, + ); + + assert_eq!(announced.len(), 1); + assert_eq!(withdrawn.len(), 1); + + match &announced[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &new_c), + _ => panic!("expected announced RouteOrigin"), + } + + match &withdrawn[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a), + _ => panic!("expected withdrawn RouteOrigin"), + } +} + +#[test] +fn snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending() { + let v4_low = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let v4_high = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let v6 = v6_origin( + Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + 32, + 48, + 64498, + ); + + let input_payloads = vec![ + Payload::RouteOrigin(v6.clone()), + Payload::RouteOrigin(v4_low.clone()), + Payload::RouteOrigin(v4_high.clone()), + ]; + + let snapshot = Snapshot::from_payloads(input_payloads.clone()); + let output_payloads = snapshot.payloads_for_rtr(); + + let input = format!( + "原始输入 payloads(构造 Snapshot 前):\n{}", + indent_block(&payloads_to_string(&input_payloads), 2), + ); + + let output = format!( + "排序后 payloads_for_rtr:\n{}", + indent_block(&payloads_to_string(&output_payloads), 2), + ); + + test_report( + "snapshot_payloads_for_rtr_sorts_ipv4_before_ipv6_and_ipv4_announcements_descending", + "验证 Snapshot::payloads_for_rtr() 会按 RTR 规则排序:IPv4 在 IPv6 前,且 IPv4 announcement 按地址降序。", + &input, + &output, + ); + + assert_eq!(output_payloads.len(), 3); + + let first = as_v4_route_origin(&output_payloads[0]); + let second = as_v4_route_origin(&output_payloads[1]); + + assert_eq!( + first.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(198, 51, 100, 0)) + ); + assert_eq!( + second.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(192, 0, 2, 0)) + ); + + let third = as_route_origin(&output_payloads[2]); + assert!(third.prefix().address.is_ipv6()); + assert_eq!( + third.prefix().address.to_ipv6(), + Some(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)) + ); +} + +#[test] +fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { + let announced_low = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let announced_high = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let withdrawn_high = v4_origin(203, 0, 113, 0, 24, 24, 64501); + let withdrawn_low = v4_origin(10, 0, 0, 0, 24, 24, 64500); + + let input_announced = vec![ + Payload::RouteOrigin(announced_low), + Payload::RouteOrigin(announced_high), + ]; + let input_withdrawn = vec![ + Payload::RouteOrigin(withdrawn_high), + Payload::RouteOrigin(withdrawn_low), + ]; + + let delta = Delta::new(101, input_announced.clone(), input_withdrawn.clone()); + + let input = format!( + "announced(构造前):\n{}withdrawn(构造前):\n{}", + indent_block(&payloads_to_string(&input_announced), 2), + indent_block(&payloads_to_string(&input_withdrawn), 2), + ); + + let output = indent_block(&delta_to_string(&delta), 2); + + test_report( + "delta_new_sorts_announced_descending_and_withdrawn_ascending", + "验证 Delta::new() 会自动排序:announced 按 RTR announcement 规则,withdrawn 按 RTR withdrawal 规则。", + &input, + &output, + ); + + assert_eq!(delta.serial(), 101); + assert_eq!(delta.announced().len(), 2); + assert_eq!(delta.withdrawn().len(), 2); + + let a0 = as_v4_route_origin(&delta.announced()[0]); + let a1 = as_v4_route_origin(&delta.announced()[1]); + assert_eq!( + a0.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(198, 51, 100, 0)) + ); + assert_eq!( + a1.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(192, 0, 2, 0)) + ); + + let w0 = as_v4_route_origin(&delta.withdrawn()[0]); + let w1 = as_v4_route_origin(&delta.withdrawn()[1]); + assert_eq!(w0.prefix().address.to_ipv4(), Some(Ipv4Addr::new(10, 0, 0, 0))); + assert_eq!( + w1.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(203, 0, 113, 0)) + ); +} + +#[test] +fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() { + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .build(); + + let result = cache.get_deltas_since(42, 100); + + let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 100); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_up_to_date_when_client_serial_matches_current", + "验证当客户端 serial 与缓存当前 serial 相同,返回 UpToDate。", + &input, + &output, + ); + + match result { + SerialResult::UpToDate => {} + _ => panic!("expected UpToDate"), + } +} + +#[test] +fn get_deltas_since_returns_reset_required_on_session_mismatch() { + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .build(); + + let result = cache.get_deltas_since(999, 100); + + let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 999, 100); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_reset_required_on_session_mismatch", + "验证当客户端 session_id 与缓存 session_id 不一致时,返回 ResetRequired。", + &input, + &output, + ); + + match result { + SerialResult::ResetRequired => {} + _ => panic!("expected ResetRequired"), + } +} + +#[test] +fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() { + let d1 = Arc::new(Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + )); + let d2 = Arc::new(Delta::new( + 102, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(102) + .timing(Timing::default()) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(42, 99); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 99), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_reset_required_when_client_serial_is_too_old", + "验证当客户端 serial 太旧,已超出 delta window 覆盖范围时,返回 ResetRequired。", + &input, + &output, + ); + + match result { + SerialResult::ResetRequired => {} + _ => panic!("expected ResetRequired"), + } +} + +#[test] +fn get_deltas_since_returns_applicable_deltas_in_order() { + let d1 = Arc::new(Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + )); + let d2 = Arc::new(Delta::new( + 102, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + )); + let d3 = Arc::new(Delta::new( + 103, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(d1); + deltas.push_back(d2); + deltas.push_back(d3); + + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(103) + .timing(Timing::default()) + .deltas(deltas.clone()) + .build(); + + let result = cache.get_deltas_since(42, 101); + + let input = format!( + "{}delta_window:\n{}", + get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101), + indent_block(&deltas_window_to_string(&deltas), 2), + ); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_applicable_deltas_in_order", + "验证当客户端 serial 在 delta window 内时,返回正确且有序的 deltas。", + &input, + &output, + ); + + match result { + SerialResult::Deltas(result) => { + assert_eq!(result.len(), 2); + assert_eq!(result[0].serial(), 102); + assert_eq!(result[1].serial(), 103); + } + _ => panic!("expected Deltas"), + } +} + +#[test] +fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() { + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .build(); + + let result = cache.get_deltas_since(42, 101); + + let input = get_deltas_since_input_to_string(cache.session_id(), cache.serial(), 42, 101); + let output = serial_result_detail_to_string(&result); + + test_report( + "get_deltas_since_returns_reset_required_when_client_serial_is_in_future", + "验证当客户端 serial 比缓存当前 serial 还大时,返回 ResetRequired。", + &input, + &output, + ); + + match result { + SerialResult::ResetRequired => {} + _ => panic!("expected ResetRequired"), + } +} + +#[tokio::test] +async fn update_no_change_keeps_serial_and_produces_no_delta() { + let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + + let old_input = vec![ + Payload::RouteOrigin(old_a.clone()), + Payload::RouteOrigin(old_b.clone()), + ]; + let snapshot = Snapshot::from_payloads(old_input.clone()); + + let mut cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .snapshot(snapshot.clone()) + .build(); + + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let new_payloads = vec![ + Payload::RouteOrigin(old_b), + Payload::RouteOrigin(old_a), + ]; + + cache.update(new_payloads.clone(), &store).unwrap(); + + let current_snapshot = cache.snapshot(); + let result = cache.get_deltas_since(42, 100); + + let input = format!( + "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + indent_block(&payloads_to_string(&old_input), 2), + indent_block(&payloads_to_string(&new_payloads), 2), + ); + + let output = format!( + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + cache.serial(), + indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block(&serial_result_detail_to_string(&result), 2), + ); + + test_report( + "update_no_change_keeps_serial_and_produces_no_delta", + "验证 update() 在新旧内容完全相同时:serial 不变、snapshot 不变、不会产生新的 delta。", + &input, + &output, + ); + + assert_eq!(cache.serial(), 100); + assert!(cache.snapshot().same_content(&snapshot)); + + match result { + SerialResult::UpToDate => {} + _ => panic!("expected UpToDate"), + } +} + +#[tokio::test] +async fn update_add_only_increments_serial_and_generates_announced_delta() { + let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let new_b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + + let old_input = vec![Payload::RouteOrigin(old_a.clone())]; + let old_snapshot = Snapshot::from_payloads(old_input.clone()); + + let mut cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .snapshot(old_snapshot.clone()) + .build(); + + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let new_payloads = vec![ + Payload::RouteOrigin(old_a.clone()), + Payload::RouteOrigin(new_b.clone()), + ]; + + cache.update(new_payloads.clone(), &store).unwrap(); + + let current_snapshot = cache.snapshot(); + let result = cache.get_deltas_since(42, 100); + + let input = format!( + "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + indent_block(&payloads_to_string(&old_input), 2), + indent_block(&payloads_to_string(&new_payloads), 2), + ); + + let output = format!( + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + cache.serial(), + indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block(&serial_result_detail_to_string(&result), 2), + ); + + test_report( + "update_add_only_increments_serial_and_generates_announced_delta", + "验证 update() 在只新增 payload 时:serial + 1,delta 中只有 announced,withdrawn 为空。", + &input, + &output, + ); + + assert_eq!(cache.serial(), 101); + + match result { + SerialResult::Deltas(deltas) => { + assert_eq!(deltas.len(), 1); + let delta = &deltas[0]; + + assert_eq!(delta.serial(), 101); + assert_eq!(delta.announced().len(), 1); + assert_eq!(delta.withdrawn().len(), 0); + + match &delta.announced()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &new_b), + _ => panic!("expected announced RouteOrigin"), + } + } + _ => panic!("expected Deltas"), + } +} + +#[tokio::test] +async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { + let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + + let old_input = vec![ + Payload::RouteOrigin(old_a.clone()), + Payload::RouteOrigin(old_b.clone()), + ]; + let old_snapshot = Snapshot::from_payloads(old_input.clone()); + + let mut cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .snapshot(old_snapshot.clone()) + .build(); + + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let new_payloads = vec![Payload::RouteOrigin(old_b.clone())]; + + cache.update(new_payloads.clone(), &store).unwrap(); + + let current_snapshot = cache.snapshot(); + let result = cache.get_deltas_since(42, 100); + + let input = format!( + "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + indent_block(&payloads_to_string(&old_input), 2), + indent_block(&payloads_to_string(&new_payloads), 2), + ); + + let output = format!( + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + cache.serial(), + indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block(&serial_result_detail_to_string(&result), 2), + ); + + test_report( + "update_remove_only_increments_serial_and_generates_withdrawn_delta", + "验证 update() 在只删除 payload 时:serial + 1,delta 中只有 withdrawn,announced 为空。", + &input, + &output, + ); + + assert_eq!(cache.serial(), 101); + + match result { + SerialResult::Deltas(deltas) => { + assert_eq!(deltas.len(), 1); + let delta = &deltas[0]; + + assert_eq!(delta.serial(), 101); + assert_eq!(delta.announced().len(), 0); + assert_eq!(delta.withdrawn().len(), 1); + + match &delta.withdrawn()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a), + _ => panic!("expected withdrawn RouteOrigin"), + } + } + _ => panic!("expected Deltas"), + } +} + +#[tokio::test] +async fn update_add_and_remove_increments_serial_and_generates_both_sides() { + let old_a = v4_origin(192, 0, 2, 0, 24, 24, 64496); + let old_b = v4_origin(198, 51, 100, 0, 24, 24, 64497); + let new_c = v6_origin( + Ipv6Addr::new(0x2001, 0xdb8, 0, 1, 0, 0, 0, 0), + 48, + 48, + 64499, + ); + + let old_input = vec![ + Payload::RouteOrigin(old_a.clone()), + Payload::RouteOrigin(old_b.clone()), + ]; + let old_snapshot = Snapshot::from_payloads(old_input.clone()); + + let mut cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::default()) + .snapshot(old_snapshot.clone()) + .build(); + + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let new_payloads = vec![ + Payload::RouteOrigin(old_b.clone()), + Payload::RouteOrigin(new_c.clone()), + ]; + + cache.update(new_payloads.clone(), &store).unwrap(); + + let current_snapshot = cache.snapshot(); + let result = cache.get_deltas_since(42, 100); + + let input = format!( + "old_snapshot 原始输入:\n{}new_payloads 原始输入:\n{}", + indent_block(&payloads_to_string(&old_input), 2), + indent_block(&payloads_to_string(&new_payloads), 2), + ); + + let output = format!( + "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(42, 100):\n{}", + cache.serial(), + indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block(&serial_result_detail_to_string(&result), 2), + ); + + test_report( + "update_add_and_remove_increments_serial_and_generates_both_sides", + "验证 update() 在同时新增和删除 payload 时:serial + 1,delta 中 announced 和 withdrawn 都正确。", + &input, + &output, + ); + + assert_eq!(cache.serial(), 101); + + match result { + SerialResult::Deltas(deltas) => { + assert_eq!(deltas.len(), 1); + let delta = &deltas[0]; + + assert_eq!(delta.serial(), 101); + assert_eq!(delta.announced().len(), 1); + assert_eq!(delta.withdrawn().len(), 1); + + match &delta.announced()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &new_c), + _ => panic!("expected announced RouteOrigin"), + } + + match &delta.withdrawn()[0] { + Payload::RouteOrigin(ro) => assert_eq!(ro, &old_a), + _ => panic!("expected withdrawn RouteOrigin"), + } + } + _ => panic!("expected Deltas"), + } +} \ No newline at end of file diff --git a/tests/test_session.rs b/tests/test_session.rs new file mode 100644 index 0000000..e7ed8e8 --- /dev/null +++ b/tests/test_session.rs @@ -0,0 +1,492 @@ +mod common; + +use std::collections::VecDeque; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::{Arc, RwLock}; + +use tokio::net::TcpListener; +use tokio::sync::{broadcast, watch}; + +use common::test_helper::{ + dump_cache_reset, dump_cache_response, dump_eod_v1, dump_ipv4_prefix, dump_ipv6_prefix, + RtrDebugDumper, +}; + +use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use rpki::rtr::cache::{Delta, SharedRtrCache, RtrCacheBuilder, Snapshot}; +use rpki::rtr::payload::{Payload, RouteOrigin, Timing}; +use rpki::rtr::pdu::{ + CacheResponse, CacheReset, EndOfDataV1, IPv4Prefix, IPv6Prefix, ResetQuery, SerialQuery, +}; +use rpki::rtr::session::RtrSession; + +fn shared_cache(cache: rpki::rtr::cache::RtrCache) -> SharedRtrCache { + Arc::new(RwLock::new(cache)) +} + +#[tokio::test] +async fn reset_query_returns_snapshot_and_end_of_data() { + let prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)), + prefix_length: 24, + }; + let origin = RouteOrigin::new(prefix, 24, 64496u32.into()); + + let snapshot = Snapshot::from_payloads(vec![Payload::RouteOrigin(origin)]); + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.pdu(), 3); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let prefix = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(prefix.pdu(), dump_ipv4_prefix(&prefix)); + assert_eq!(prefix.pdu(), 4); + assert_eq!(prefix.version(), 1); + assert!(prefix.flag().is_announce()); + assert_eq!(prefix.prefix_len(), 24); + assert_eq!(prefix.max_len(), 24); + assert_eq!(prefix.prefix(), Ipv4Addr::new(192, 0, 2, 0)); + assert_eq!(prefix.asn(), 64496u32.into()); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + let timing = eod.timing(); + assert_eq!(timing.refresh, 600); + assert_eq!(timing.retry, 600); + assert_eq!(timing.expire, 7200); + + dump.print_pretty("reset_query_returns_snapshot_and_end_of_data"); +} + +#[tokio::test] +async fn serial_query_returns_end_of_data_when_up_to_date() { + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing { + refresh: 600, + retry: 600, + expire: 7200, + }) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + let timing = eod.timing(); + assert_eq!(timing.refresh, 600); + assert_eq!(timing.retry, 600); + assert_eq!(timing.expire, 7200); + + dump.print_pretty("serial_query_returns_end_of_data_when_up_to_date"); +} + +#[tokio::test] +async fn serial_query_returns_cache_reset_when_session_id_mismatch() { + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing { + refresh: 600, + retry: 600, + expire: 7200, + }) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 999, 100).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let reset = CacheReset::read(&mut client).await.unwrap(); + dump.push_value(reset.pdu(), dump_cache_reset(reset.version(), reset.pdu())); + assert_eq!(reset.pdu(), 8); + assert_eq!(reset.version(), 1); + + dump.print_pretty("serial_query_returns_cache_reset_when_session_id_mismatch"); +} + +#[tokio::test] +async fn serial_query_returns_deltas_when_incremental_update_available() { + let prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)), + prefix_length: 24, + }; + let origin = RouteOrigin::new(prefix, 24, 64496u32.into()); + + let delta = Arc::new(Delta::new( + 101, + vec![Payload::RouteOrigin(origin)], + vec![], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(delta); + + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(101) + .timing(Timing { + refresh: 600, + retry: 600, + expire: 7200, + }) + .deltas(deltas) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.pdu(), 3); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let prefix = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(prefix.pdu(), dump_ipv4_prefix(&prefix)); + assert_eq!(prefix.pdu(), 4); + assert_eq!(prefix.version(), 1); + assert!(prefix.flag().is_announce()); + assert_eq!(prefix.prefix_len(), 24); + assert_eq!(prefix.max_len(), 24); + assert_eq!(prefix.prefix(), Ipv4Addr::new(192, 0, 2, 0)); + assert_eq!(prefix.asn(), 64496u32.into()); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 101); + + let timing = eod.timing(); + assert_eq!(timing.refresh, 600); + assert_eq!(timing.retry, 600); + assert_eq!(timing.expire, 7200); + + dump.print_pretty("serial_query_returns_deltas_when_incremental_update_available"); +} + +#[tokio::test] +async fn reset_query_returns_payloads_in_rtr_order() { + let v4_low_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)), + prefix_length: 24, + }; + let v4_low_origin = RouteOrigin::new(v4_low_prefix, 24, 64496u32.into()); + + let v4_high_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(198, 51, 100, 0)), + prefix_length: 24, + }; + let v4_high_origin = RouteOrigin::new(v4_high_prefix, 24, 64497u32.into()); + + let v6_prefix = IPAddressPrefix { + address: IPAddress::from_ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)), + prefix_length: 32, + }; + let v6_origin = RouteOrigin::new(v6_prefix, 48, 64498u32.into()); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v6_origin), + Payload::RouteOrigin(v4_low_origin), + Payload::RouteOrigin(v4_high_origin), + ]); + + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(100) + .timing(Timing::new(600, 600, 7200)) + .snapshot(snapshot) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + ResetQuery::new(1).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.pdu(), 3); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let first = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(first.pdu(), dump_ipv4_prefix(&first)); + assert_eq!(first.pdu(), 4); + assert_eq!(first.version(), 1); + assert!(first.flag().is_announce()); + assert_eq!(first.prefix(), Ipv4Addr::new(198, 51, 100, 0)); + assert_eq!(first.prefix_len(), 24); + assert_eq!(first.max_len(), 24); + assert_eq!(first.asn(), 64497u32.into()); + + let second = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(second.pdu(), dump_ipv4_prefix(&second)); + assert_eq!(second.pdu(), 4); + assert_eq!(second.version(), 1); + assert!(second.flag().is_announce()); + assert_eq!(second.prefix(), Ipv4Addr::new(192, 0, 2, 0)); + assert_eq!(second.prefix_len(), 24); + assert_eq!(second.max_len(), 24); + assert_eq!(second.asn(), 64496u32.into()); + + assert!(u32::from(first.prefix()) > u32::from(second.prefix())); + + let third = IPv6Prefix::read(&mut client).await.unwrap(); + dump.push_value(third.pdu(), dump_ipv6_prefix(&third)); + assert_eq!(third.pdu(), 6); + assert_eq!(third.version(), 1); + assert!(third.flag().is_announce()); + assert_eq!(third.prefix(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)); + assert_eq!(third.prefix_len(), 32); + assert_eq!(third.max_len(), 48); + assert_eq!(third.asn(), 64498u32.into()); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 100); + + let timing = eod.timing(); + assert_eq!(timing.refresh, 600); + assert_eq!(timing.retry, 600); + assert_eq!(timing.expire, 7200); + + dump.print_pretty("reset_query_returns_payloads_in_rtr_order"); +} + +#[tokio::test] +async fn serial_query_returns_announcements_before_withdrawals() { + let announced_low_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(192, 0, 2, 0)), + prefix_length: 24, + }; + let announced_low_origin = RouteOrigin::new(announced_low_prefix, 24, 64496u32.into()); + + let announced_high_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(198, 51, 100, 0)), + prefix_length: 24, + }; + let announced_high_origin = RouteOrigin::new(announced_high_prefix, 24, 64497u32.into()); + + let withdrawn_low_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(10, 0, 0, 0)), + prefix_length: 24, + }; + let withdrawn_low_origin = RouteOrigin::new(withdrawn_low_prefix, 24, 64500u32.into()); + + let withdrawn_high_prefix = IPAddressPrefix { + address: IPAddress::from_ipv4(Ipv4Addr::new(203, 0, 113, 0)), + prefix_length: 24, + }; + let withdrawn_high_origin = RouteOrigin::new(withdrawn_high_prefix, 24, 64501u32.into()); + + let delta = Arc::new(Delta::new( + 101, + vec![ + Payload::RouteOrigin(announced_low_origin), + Payload::RouteOrigin(announced_high_origin), + ], + vec![ + Payload::RouteOrigin(withdrawn_high_origin), + Payload::RouteOrigin(withdrawn_low_origin), + ], + )); + + let mut deltas = VecDeque::new(); + deltas.push_back(delta); + + let cache = RtrCacheBuilder::new() + .session_id(42) + .serial(101) + .timing(Timing { + refresh: 600, + retry: 600, + expire: 7200, + }) + .deltas(deltas) + .build(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_cache = shared_cache(cache); + let (_notify_tx, notify_rx) = broadcast::channel(16); + let (_shutdown_tx, shutdown_rx) = watch::channel(false); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let session = RtrSession::new(server_cache, stream, notify_rx, shutdown_rx); + session.run().await.unwrap(); + }); + + let mut client = tokio::net::TcpStream::connect(addr).await.unwrap(); + SerialQuery::new(1, 42, 100).write(&mut client).await.unwrap(); + + let mut dump = RtrDebugDumper::new(); + + let response = CacheResponse::read(&mut client).await.unwrap(); + dump.push_value(response.pdu(), dump_cache_response(&response)); + assert_eq!(response.pdu(), 3); + assert_eq!(response.version(), 1); + assert_eq!(response.session_id(), 42); + + let first = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(first.pdu(), dump_ipv4_prefix(&first)); + assert_eq!(first.pdu(), 4); + assert_eq!(first.version(), 1); + assert!(first.flag().is_announce()); + assert_eq!(first.prefix(), Ipv4Addr::new(198, 51, 100, 0)); + assert_eq!(first.prefix_len(), 24); + assert_eq!(first.max_len(), 24); + assert_eq!(first.asn(), 64497u32.into()); + + let second = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(second.pdu(), dump_ipv4_prefix(&second)); + assert_eq!(second.pdu(), 4); + assert_eq!(second.version(), 1); + assert!(second.flag().is_announce()); + assert_eq!(second.prefix(), Ipv4Addr::new(192, 0, 2, 0)); + assert_eq!(second.prefix_len(), 24); + assert_eq!(second.max_len(), 24); + assert_eq!(second.asn(), 64496u32.into()); + + assert!(u32::from(first.prefix()) > u32::from(second.prefix())); + + let third = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(third.pdu(), dump_ipv4_prefix(&third)); + assert_eq!(third.pdu(), 4); + assert_eq!(third.version(), 1); + assert!(!third.flag().is_announce()); + assert_eq!(third.prefix(), Ipv4Addr::new(10, 0, 0, 0)); + assert_eq!(third.prefix_len(), 24); + assert_eq!(third.max_len(), 24); + assert_eq!(third.asn(), 64500u32.into()); + + let fourth = IPv4Prefix::read(&mut client).await.unwrap(); + dump.push_value(fourth.pdu(), dump_ipv4_prefix(&fourth)); + assert_eq!(fourth.pdu(), 4); + assert_eq!(fourth.version(), 1); + assert!(!fourth.flag().is_announce()); + assert_eq!(fourth.prefix(), Ipv4Addr::new(203, 0, 113, 0)); + assert_eq!(fourth.prefix_len(), 24); + assert_eq!(fourth.max_len(), 24); + assert_eq!(fourth.asn(), 64501u32.into()); + + assert!(u32::from(third.prefix()) < u32::from(fourth.prefix())); + assert!(first.flag().is_announce()); + assert!(second.flag().is_announce()); + assert!(!third.flag().is_announce()); + assert!(!fourth.flag().is_announce()); + + let eod = EndOfDataV1::read(&mut client).await.unwrap(); + dump.push_value(eod.pdu(), dump_eod_v1(&eod)); + assert_eq!(eod.pdu(), 7); + assert_eq!(eod.version(), 1); + assert_eq!(eod.session_id(), 42); + assert_eq!(eod.serial_number(), 101); + + let timing = eod.timing(); + assert_eq!(timing.refresh, 600); + assert_eq!(timing.retry, 600); + assert_eq!(timing.expire, 7200); + + dump.print_pretty("serial_query_returns_announcements_before_withdrawals"); +} \ No newline at end of file diff --git a/tests/test_store_db.rs b/tests/test_store_db.rs new file mode 100644 index 0000000..a1550d2 --- /dev/null +++ b/tests/test_store_db.rs @@ -0,0 +1,352 @@ +mod common; + +use std::net::Ipv6Addr; + +use common::test_helper::{ + indent_block, payloads_to_string, test_report, v4_origin, v6_origin, +}; + +use rpki::rtr::cache::{Delta, Snapshot}; +use rpki::rtr::payload::Payload; +use rpki::rtr::store_db::RtrStore; + +fn snapshot_to_string(snapshot: &Snapshot) -> String { + let payloads = snapshot.payloads_for_rtr(); + payloads_to_string(&payloads) +} + +fn delta_to_string(delta: &Delta) -> String { + format!( + "serial: {}\nannounced:\n{}withdrawn:\n{}", + delta.serial(), + indent_block(&payloads_to_string(delta.announced()), 2), + indent_block(&payloads_to_string(delta.withdrawn()), 2), + ) +} + +#[test] +fn store_db_save_and_get_snapshot() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let input_payloads = vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v6_origin( + Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + 32, + 48, + 64497, + )), + ]; + let snapshot = Snapshot::from_payloads(input_payloads.clone()); + + store.save_snapshot(&snapshot).unwrap(); + let loaded = store.get_snapshot().unwrap().expect("snapshot should exist"); + + let input = format!( + "db_path: {}\nsnapshot:\n{}", + dir.path().display(), + indent_block(&payloads_to_string(&input_payloads), 2), + ); + + let output = format!( + "loaded snapshot:\n{}same_content: {}\n", + indent_block(&snapshot_to_string(&loaded), 2), + snapshot.same_content(&loaded), + ); + + test_report( + "store_db_save_and_get_snapshot", + "验证 save_snapshot() 后可以通过 get_snapshot() 正确读回 Snapshot。", + &input, + &output, + ); + + assert!(snapshot.same_content(&loaded)); +} + +#[test] +fn store_db_set_and_get_meta_fields() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + store.set_session_id(42).unwrap(); + store.set_serial(100).unwrap(); + store.set_delta_window(101, 110).unwrap(); + + let session_id = store.get_session_id().unwrap(); + let serial = store.get_serial().unwrap(); + let window = store.get_delta_window().unwrap(); + + let input = format!( + "db_path: {}\nset_session_id=42\nset_serial=100\nset_delta_window=(101, 110)\n", + dir.path().display(), + ); + + let output = format!( + "get_session_id: {:?}\nget_serial: {:?}\nget_delta_window: {:?}\n", + session_id, serial, window, + ); + + test_report( + "store_db_set_and_get_meta_fields", + "验证 session_id / serial / delta_window 能正确写入并读回。", + &input, + &output, + ); + + assert_eq!(session_id, Some(42)); + assert_eq!(serial, Some(100)); + assert_eq!(window, Some((101, 110))); +} + +#[test] +fn store_db_save_and_get_delta() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let delta = Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + ); + + store.save_delta(&delta).unwrap(); + let loaded = store.get_delta(101).unwrap().expect("delta should exist"); + + let input = format!( + "db_path: {}\ndelta:\n{}", + dir.path().display(), + indent_block(&delta_to_string(&delta), 2), + ); + + let output = format!( + "loaded delta:\n{}", + indent_block(&delta_to_string(&loaded), 2), + ); + + test_report( + "store_db_save_and_get_delta", + "验证 save_delta() 后可以通过 get_delta(serial) 正确读回 Delta。", + &input, + &output, + ); + + assert_eq!(loaded.serial(), 101); + assert_eq!(loaded.announced().len(), 1); + assert_eq!(loaded.withdrawn().len(), 1); +} + +#[test] +fn store_db_load_deltas_since_returns_only_newer_deltas_in_order() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let d101 = Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496))], + vec![], + ); + let d102 = Delta::new( + 102, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + ); + let d103 = Delta::new( + 103, + vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![], + ); + + store.save_delta(&d101).unwrap(); + store.save_delta(&d102).unwrap(); + store.save_delta(&d103).unwrap(); + + let loaded = store.load_deltas_since(101).unwrap(); + + let input = format!( + "db_path: {}\nsaved delta serials: [101, 102, 103]\nload_deltas_since(101)\n", + dir.path().display(), + ); + + let output = { + let mut s = String::new(); + for (idx, d) in loaded.iter().enumerate() { + s.push_str(&format!("loaded[{}]:\n", idx)); + s.push_str(&indent_block(&delta_to_string(d), 2)); + } + s + }; + + test_report( + "store_db_load_deltas_since_returns_only_newer_deltas_in_order", + "验证 load_deltas_since(x) 只返回 serial > x 的 Delta,且顺序正确。", + &input, + &output, + ); + + assert_eq!(loaded.len(), 2); + assert_eq!(loaded[0].serial(), 102); + assert_eq!(loaded[1].serial(), 103); +} + +#[test] +fn store_db_save_snapshot_and_meta_writes_all_fields() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)), + ]); + + store.save_snapshot_and_meta(&snapshot, 42, 100).unwrap(); + + let loaded_snapshot = store.get_snapshot().unwrap().expect("snapshot should exist"); + let loaded_session = store.get_session_id().unwrap(); + let loaded_serial = store.get_serial().unwrap(); + + let input = format!( + "db_path: {}\nsnapshot:\n{}session_id=42\nserial=100\n", + dir.path().display(), + indent_block(&snapshot_to_string(&snapshot), 2), + ); + + let output = format!( + "loaded_snapshot:\n{}loaded_session_id: {:?}\nloaded_serial: {:?}\n", + indent_block(&snapshot_to_string(&loaded_snapshot), 2), + loaded_session, + loaded_serial, + ); + + test_report( + "store_db_save_snapshot_and_meta_writes_all_fields", + "验证 save_snapshot_and_meta() 会同时写入 snapshot、session_id 和 serial。", + &input, + &output, + ); + + assert!(snapshot.same_content(&loaded_snapshot)); + assert_eq!(loaded_session, Some(42)); + assert_eq!(loaded_serial, Some(100)); +} + +#[test] +fn store_db_load_snapshot_and_serial_returns_consistent_pair() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498)), + ]); + + store.save_snapshot_and_serial(&snapshot, 200).unwrap(); + + let loaded = store + .load_snapshot_and_serial() + .unwrap() + .expect("snapshot+serial should exist"); + + let input = format!( + "db_path: {}\nsnapshot:\n{}serial=200\n", + dir.path().display(), + indent_block(&snapshot_to_string(&snapshot), 2), + ); + + let output = format!( + "loaded_snapshot:\n{}loaded_serial: {}\n", + indent_block(&snapshot_to_string(&loaded.0), 2), + loaded.1, + ); + + test_report( + "store_db_load_snapshot_and_serial_returns_consistent_pair", + "验证 load_snapshot_and_serial() 能正确返回一致的 snapshot 与 serial。", + &input, + &output, + ); + + assert!(snapshot.same_content(&loaded.0)); + assert_eq!(loaded.1, 200); +} + +#[test] +fn store_db_delete_snapshot_delta_and_serial_removes_data() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + ]); + let delta = Delta::new( + 101, + vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![], + ); + + store.save_snapshot(&snapshot).unwrap(); + store.save_delta(&delta).unwrap(); + store.set_serial(100).unwrap(); + + store.delete_snapshot().unwrap(); + store.delete_delta(101).unwrap(); + store.delete_serial().unwrap(); + + let loaded_snapshot = store.get_snapshot().unwrap(); + let loaded_delta = store.get_delta(101).unwrap(); + let loaded_serial = store.get_serial().unwrap(); + + let input = format!( + "db_path: {}\nsave snapshot + delta(101) + serial(100), then delete all three.\n", + dir.path().display(), + ); + + let output = format!( + "get_snapshot: {:?}\nget_delta(101): {:?}\nget_serial: {:?}\n", + loaded_snapshot.as_ref().map(|_| "Some(snapshot)"), + loaded_delta.as_ref().map(|_| "Some(delta)"), + loaded_serial, + ); + + test_report( + "store_db_delete_snapshot_delta_and_serial_removes_data", + "验证 delete_snapshot()/delete_delta()/delete_serial() 后,对应数据不再可读。", + &input, + &output, + ); + + assert!(loaded_snapshot.is_none()); + assert!(loaded_delta.is_none()); + assert!(loaded_serial.is_none()); +} + +#[test] +fn store_db_load_snapshot_and_serial_errors_on_inconsistent_state() { + let dir = tempfile::tempdir().unwrap(); + let store = RtrStore::open(dir.path()).unwrap(); + + let snapshot = Snapshot::from_payloads(vec![ + Payload::RouteOrigin(v4_origin(192, 0, 2, 0, 24, 24, 64496)), + ]); + + store.save_snapshot(&snapshot).unwrap(); + // 故意不写 serial,制造不一致状态 + + let result = store.load_snapshot_and_serial(); + + let input = format!( + "db_path: {}\n仅保存 snapshot,不保存 serial。\n", + dir.path().display(), + ); + + let output = format!("load_snapshot_and_serial result: {:?}\n", result); + + test_report( + "store_db_load_snapshot_and_serial_errors_on_inconsistent_state", + "验证当 snapshot 和 serial 状态不一致时,load_snapshot_and_serial() 返回错误。", + &input, + &output, + ); + + assert!(result.is_err()); +} \ No newline at end of file