rpki/tests/test_pdu.rs
2026-03-25 10:08:40 +08:00

205 lines
6.6 KiB
Rust

use std::net::Ipv4Addr;

use rpki::data_model::resources::as_resources::Asn;
use rpki::rtr::error_type::ErrorCode;
use rpki::rtr::payload::{Aspa as PayloadAspa, Ski, Timing};
use rpki::rtr::pdu::{
    Aspa, EndOfDataV1, ErrorReport, Flags, Header, IPv4Prefix, RouterKey, SerialNotify,
    END_OF_DATA_V1_LEN, MAX_PDU_LEN,
};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
/// Size of the non-variable part of an Error Report PDU: the 8-byte PDU header
/// plus two 4-byte length fields (encapsulated-PDU length and error-text length).
const ERROR_REPORT_FIXED_PART_LEN: usize = 8 + 4 + 4;
/// Round-trips a Serial Notify PDU through an in-memory duplex pipe and checks
/// that version, session id and serial number survive encode + decode.
#[tokio::test]
async fn serial_notify_roundtrip() {
    let (mut client, mut server) = duplex(1024);
    let original = SerialNotify::new(1, 42, 100);
    // Return the write result from the task instead of unwrapping inside it and
    // detaching: otherwise a failed write only shows up here as a confusing EOF
    // on the reader. `client` is dropped when the task finishes, so a failed
    // write cannot leave the reader waiting for bytes.
    let writer = tokio::spawn(async move { original.write(&mut client).await });
    let decoded = SerialNotify::read(&mut server).await;
    // Check the writer first so its error is reported as the root cause.
    writer
        .await
        .expect("writer task panicked")
        .expect("SerialNotify::write failed");
    let decoded = decoded.expect("SerialNotify::read failed");
    assert_eq!(decoded.version(), 1);
    assert_eq!(decoded.session_id(), 42);
    assert_eq!(decoded.serial_number(), 100);
}
/// Round-trips an IPv4 Prefix PDU through an in-memory duplex pipe and checks
/// the decoded prefix, prefix/max lengths and the announce flag.
#[tokio::test]
async fn ipv4_prefix_roundtrip() {
    let (mut client, mut server) = duplex(1024);
    let prefix = IPv4Prefix::new(
        1,
        Flags::new(1),
        24,
        24,
        Ipv4Addr::new(192, 168, 0, 0),
        65000u32.into(),
    );
    // Keep the JoinHandle and propagate the write result, so a write failure is
    // reported directly instead of as an EOF seen by the reader. The task drops
    // `client` on completion, which unblocks the reader on failure.
    let writer = tokio::spawn(async move { prefix.write(&mut client).await });
    let decoded = IPv4Prefix::read(&mut server).await;
    writer
        .await
        .expect("writer task panicked")
        .expect("IPv4Prefix::write failed");
    let decoded = decoded.expect("IPv4Prefix::read failed");
    assert_eq!(decoded.prefix_len(), 24);
    assert_eq!(decoded.max_len(), 24);
    assert_eq!(decoded.prefix(), Ipv4Addr::new(192, 168, 0, 0));
    assert!(decoded.flag().is_announce());
}
/// An encapsulated PDU as large as `MAX_PDU_LEN` is clipped so the whole Error
/// Report fits in `MAX_PDU_LEN`; the clipped PDU uses all the remaining room,
/// leaving no space for the error text.
#[test]
fn error_report_truncates_large_erroneous_pdu() {
    let max = MAX_PDU_LEN as usize;
    let oversized = vec![0xAA; max];
    let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), &oversized, b"details");

    let room_for_pdu = max - ERROR_REPORT_FIXED_PART_LEN;
    assert_eq!(report.as_ref().len(), max);
    assert_eq!(report.erroneous_pdu(), &oversized[..room_for_pdu]);
    assert!(report.text().is_empty());
}
/// With a small encapsulated PDU, the oversized error text is the part that
/// gets clipped, so the report still totals exactly `MAX_PDU_LEN` bytes.
#[test]
fn error_report_truncates_text_to_fit() {
    let max = MAX_PDU_LEN as usize;
    let pdu = [1, 2, 3, 4];
    let long_text = vec![b'x'; max];
    let report = ErrorReport::new(1, ErrorCode::CorruptData.as_u16(), pdu, &long_text);

    let expected_text_len = max - ERROR_REPORT_FIXED_PART_LEN - pdu.len();
    assert_eq!(report.erroneous_pdu(), pdu);
    assert_eq!(report.as_ref().len(), max);
    assert_eq!(report.text().len(), expected_text_len);
}
/// Error text must be UTF-8: a hand-built Error Report whose single text byte
/// is `0xFF` must be rejected with `InvalidData` when read.
#[tokio::test]
async fn error_report_rejects_non_utf8_text() {
    let (mut client, mut server) = duplex(1024);
    // Total length 17 = 16-byte fixed part + 1 byte of error text.
    let header = Header::new(1, ErrorReport::PDU, ErrorCode::CorruptData.as_u16(), 17);

    let mut frame: Vec<u8> = Vec::with_capacity(17);
    frame.extend_from_slice(header.as_ref());
    frame.extend_from_slice(&0u32.to_be_bytes()); // encapsulated PDU length: empty
    frame.extend_from_slice(&1u32.to_be_bytes()); // error text length
    frame.push(0xFF); // 0xFF never occurs in valid UTF-8
    client.write_all(&frame).await.unwrap();

    let err = ErrorReport::read(&mut server).await.unwrap_err();
    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
}
/// Serialises a Router Key PDU and checks both the length advertised in its
/// header and the number of bytes actually written to the wire.
#[tokio::test]
async fn router_key_length_matches_wire_size() {
    // RTR header size; matches the leading `8` in the expected length below.
    const HEADER_LEN: usize = 8;

    let ski = Ski::default();
    let spki: std::sync::Arc<[u8]> = std::sync::Arc::from(vec![1u8; 32]);
    let pdu = RouterKey::new(1, Flags::new(1), ski, Asn::from(64496u32), spki);
    let (mut client, mut server) = duplex(1024);

    // The task returns the write result and drops `client` when done, which
    // gives the reader EOF right after the last byte of the PDU.
    let writer = tokio::spawn(async move { pdu.write(&mut client).await });
    let header = Header::read(&mut server).await;
    let mut body = Vec::new();
    let body_read = server.read_to_end(&mut body).await;

    // Check the writer first so a failed write is reported as the root cause
    // rather than as a short or failed read.
    writer
        .await
        .expect("writer task panicked")
        .expect("RouterKey::write failed");
    let header = header.expect("Header::read failed");
    body_read.expect("reading RouterKey body failed");

    assert_eq!(header.pdu(), RouterKey::PDU);
    // header + 20-byte SKI + 4-byte ASN + 32-byte SPKI
    assert_eq!(header.length(), 8 + 20 + 4 + 32);
    // The advertised length must equal what was really put on the wire.
    assert_eq!(HEADER_LEN + body.len(), header.length() as usize);
}
/// Serialises an ASPA PDU with two providers and checks both the length
/// advertised in its header and the number of bytes actually written.
#[tokio::test]
async fn aspa_length_matches_wire_size() {
    // RTR header size; matches the leading `8` in the expected length below.
    const HEADER_LEN: usize = 8;

    let pdu = Aspa::new(2, Flags::new(1), 64496, vec![64497, 64498]);
    let (mut client, mut server) = duplex(1024);

    // The task returns the write result and drops `client` when done, which
    // gives the reader EOF right after the last byte of the PDU.
    let writer = tokio::spawn(async move { pdu.write(&mut client).await });
    let header = Header::read(&mut server).await;
    let mut body = Vec::new();
    let body_read = server.read_to_end(&mut body).await;

    // Check the writer first so a failed write is reported as the root cause.
    writer
        .await
        .expect("writer task panicked")
        .expect("Aspa::write failed");
    let header = header.expect("Header::read failed");
    body_read.expect("reading Aspa body failed");

    assert_eq!(header.pdu(), Aspa::PDU);
    // header + customer ASN (4) + two provider ASNs (2 x 4)
    assert_eq!(header.length(), 8 + 4 + 8);
    // The advertised length must equal what was really put on the wire.
    assert_eq!(HEADER_LEN + body.len(), header.length() as usize);
}
/// An ASPA announcement with an empty provider list is rejected as invalid data.
#[test]
fn aspa_announcement_rejects_empty_provider_list() {
    let customer = Asn::from(64496u32);
    let outcome = PayloadAspa::new(customer, vec![]).validate_announcement();
    let err = outcome.unwrap_err();

    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    let message = err.to_string();
    assert!(message.contains("at least one provider"));
}
/// AS0 is rejected whether it appears as the customer or as a provider.
#[test]
fn aspa_announcement_rejects_as0() {
    let cases = [
        // AS0 as the customer
        PayloadAspa::new(Asn::from(0u32), vec![Asn::from(64497u32)]),
        // AS0 as a provider
        PayloadAspa::new(Asn::from(64496u32), vec![Asn::from(0u32)]),
    ];
    for aspa in cases {
        let err = aspa.validate_announcement().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("AS0"));
    }
}
/// A refresh interval of 0 fails `Timing::validate` with a refresh-specific error.
#[test]
fn timing_rejects_out_of_range_refresh() {
    let (refresh, retry, expire) = (0, 600, 7200);
    let err = Timing::new(refresh, retry, expire).validate().unwrap_err();

    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    let message = err.to_string();
    assert!(message.contains("refresh interval"));
}
/// An expire interval equal to (not greater than) retry and refresh is rejected.
#[test]
fn timing_rejects_expire_not_greater_than_retry_and_refresh() {
    let interval = 600;
    let timing = Timing::new(interval, interval, interval);
    let err = timing.validate().unwrap_err();

    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    let message = err.to_string();
    assert!(message.contains("expire interval"));
}
/// Constructing an End of Data v1 PDU with retry (8000) above expire (7200)
/// fails with a retry-specific error.
#[test]
fn end_of_data_v1_rejects_invalid_timing() {
    let bad_timing = Timing::new(600, 8000, 7200);
    let err = EndOfDataV1::new(1, 42, 100, bad_timing).unwrap_err();

    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    let message = err.to_string();
    assert!(message.contains("retry interval"));
}
/// Reading a full End of Data v1 frame whose retry interval exceeds its expire
/// interval must fail with a retry-specific `InvalidData` error.
#[tokio::test]
async fn end_of_data_v1_read_rejects_invalid_timing() {
    let (mut client, mut server) = duplex(1024);
    let header = Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN);

    // serial, refresh, retry, expire (retry 8000 > expire 7200)
    let fields: [u32; 4] = [100, 600, 8000, 7200];
    let mut frame: Vec<u8> = Vec::new();
    frame.extend_from_slice(header.as_ref());
    frame.extend(fields.iter().flat_map(|field| field.to_be_bytes()));
    client.write_all(&frame).await.unwrap();

    let err = EndOfDataV1::read(&mut server).await.unwrap_err();
    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    assert!(err.to_string().contains("retry interval"));
}
/// Reading only the payload (header already parsed) must reject timing values
/// where expire is not greater than refresh/retry.
#[tokio::test]
async fn end_of_data_v1_read_payload_rejects_invalid_timing() {
    let (mut client, mut server) = duplex(1024);
    let header = Header::new(1, EndOfDataV1::PDU, 42, END_OF_DATA_V1_LEN);

    // serial, refresh, retry, expire (all intervals equal to 600)
    let payload: Vec<u8> = [100u32, 600, 600, 600]
        .iter()
        .flat_map(|value| value.to_be_bytes())
        .collect();
    client.write_all(&payload).await.unwrap();

    let err = EndOfDataV1::read_payload(header, &mut server)
        .await
        .unwrap_err();
    assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    assert!(err.to_string().contains("expire interval"));
}