rpki/src/rtr/pdu.rs
2026-04-01 16:24:01 +08:00

1295 lines
36 KiB
Rust

use crate::data_model::resources::as_resources::Asn;
use crate::rtr::error_type::ErrorCode;
use crate::rtr::payload::{Ski, Timing};
use anyhow::Result;
use std::io;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::{cmp, mem};
use tokio::io::AsyncWrite;
use anyhow::bail;
use serde::Serialize;
use std::slice;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
/// Size in octets of the fixed header that starts every PDU.
pub const HEADER_LEN: usize = 8;
/// Largest header length value accepted when parsing a header, and the cap
/// applied when building an Error Report.
pub const MAX_PDU_LEN: u32 = 65535;
/// Value written to the header length field of an IPv4 Prefix PDU.
pub const IPV4_PREFIX_LEN: u32 = 20;
/// Value written to the header length field of an IPv6 Prefix PDU.
pub const IPV6_PREFIX_LEN: u32 = 32;
/// Value written to the header length field of a version 0 End of Data PDU.
pub const END_OF_DATA_V0_LEN: u32 = 12;
/// Value written to the header length field of a version 1 End of Data PDU.
pub const END_OF_DATA_V1_LEN: u32 = 24;
/// Zero filler for 16-bit header fields that carry no data.
pub const ZERO_16: u16 = 0;
/// Zero filler for reserved 8-bit fields.
pub const ZERO_8: u8 = 0;
/// Implements the byte-level plumbing shared by all fixed-size PDU types:
/// `AsRef<[u8]>`, `AsMut<[u8]>` and an async `write` that sends those bytes.
macro_rules! common {
    ( $type:ident ) => {
        // Not every generated method is used for every type.
        #[allow(dead_code)]
        impl $type {
            /// Sends the raw octets of this value to `a`.
            pub async fn write<A: AsyncWrite + Unpin>(&self, a: &mut A) -> Result<(), io::Error> {
                let octets: &[u8] = self.as_ref();
                a.write_all(octets).await
            }
        }

        impl AsRef<[u8]> for $type {
            fn as_ref(&self) -> &[u8] {
                let start = (self as *const Self).cast::<u8>();
                let len = mem::size_of::<Self>();
                // SAFETY: `start` comes from a live `&Self`, so it is valid for
                // reads of `size_of::<Self>()` bytes for the lifetime of the
                // borrow. NOTE(review): this additionally relies on the type
                // having no padding bytes; the types this macro is applied to
                // in the visible part of the file are all `#[repr(C, packed)]`.
                unsafe { slice::from_raw_parts(start, len) }
            }
        }

        impl AsMut<[u8]> for $type {
            fn as_mut(&mut self) -> &mut [u8] {
                let start = (self as *mut Self).cast::<u8>();
                let len = mem::size_of::<Self>();
                // SAFETY: `start` comes from a live `&mut Self`, so it is valid
                // for reads and writes of `size_of::<Self>()` bytes and is not
                // aliased for the lifetime of the borrow. NOTE(review): writing
                // arbitrary bytes is only sound if every bit pattern is a valid
                // `Self`; the visible users consist of integer fields only.
                unsafe { slice::from_raw_parts_mut(start, len) }
            }
        }
    };
}
/// Implements the accessors and readers shared by all fixed-size PDU types
/// that start with a `header: Header` field and define a `PDU` constant.
///
/// Also expands `common!` for the type.
macro_rules! concrete {
    ( $type:ident ) => {
        common!($type);

        // Not every generated method is used for every type.
        #[allow(dead_code)]
        impl $type {
            /// Returns the version field of the header.
            pub fn version(&self) -> u8 {
                Header::version(self.header)
            }

            /// Returns the session field of the header.
            ///
            /// Some PDU types use this field for other purposes.
            pub fn session_id(&self) -> u16 {
                Header::session_id(self.header)
            }

            /// Returns the PDU type field of the header.
            pub fn pdu(&self) -> u8 {
                Header::pdu(self.header)
            }

            /// Returns the size of the PDU in octets.
            ///
            /// The value is a `u32` because that is the type of the header's
            /// length field.
            pub fn size() -> u32 {
                let octets = mem::size_of::<Self>();
                octets as u32
            }

            /// Reads a complete PDU of this type from `sock`.
            ///
            /// # Errors
            ///
            /// Fails with `InvalidData` if the received header carries a
            /// different PDU type or a length other than the size of this
            /// type, and with the underlying error if reading fails.
            pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
                let mut value = Self::default();
                sock.read_exact(value.header.as_mut()).await?;

                let header = value.header;
                if header.pdu() != Self::PDU {
                    let message = concat!("PDU type mismatch when expecting ", stringify!($type));
                    return Err(io::Error::new(io::ErrorKind::InvalidData, message));
                }
                if header.length() as usize != mem::size_of::<Self>() {
                    let message = concat!("invalid length for ", stringify!($type));
                    return Err(io::Error::new(io::ErrorKind::InvalidData, message));
                }

                let payload = &mut value.as_mut()[Header::LEN..];
                sock.read_exact(payload).await?;
                Ok(value)
            }

            /// Reads a PDU of this type from `sock`, tolerating an Error Report.
            ///
            /// Returns `Ok(Ok(value))` for a PDU of this type. If the header
            /// announces an Error Report instead, only the header is consumed
            /// and it is returned as `Ok(Err(header))`.
            ///
            /// # Errors
            ///
            /// Fails with `InvalidData` for any other PDU type or for a length
            /// other than the size of this type, and with the underlying error
            /// if reading fails.
            pub async fn try_read<Sock: AsyncRead + Unpin>(
                sock: &mut Sock,
            ) -> Result<Result<Self, Header>, io::Error> {
                let mut value = Self::default();
                sock.read_exact(value.header.as_mut()).await?;

                let header = value.header;
                let kind = header.pdu();
                if kind == ErrorReport::PDU {
                    // The session is expected to be dropped after an error, so
                    // the remainder of the Error Report is left unread here.
                    return Ok(Err(header));
                }
                if kind != Self::PDU {
                    let message = concat!("PDU type mismatch when expecting ", stringify!($type));
                    return Err(io::Error::new(io::ErrorKind::InvalidData, message));
                }
                if header.length() as usize != mem::size_of::<Self>() {
                    let message = concat!("invalid length for ", stringify!($type));
                    return Err(io::Error::new(io::ErrorKind::InvalidData, message));
                }

                let payload = &mut value.as_mut()[Header::LEN..];
                sock.read_exact(payload).await?;
                Ok(Ok(value))
            }

            /// Reads the part of the PDU that follows an already received header.
            ///
            /// `header` is stored in the returned value unchanged; only its
            /// length field is checked here.
            ///
            /// # Errors
            ///
            /// Fails with `InvalidData` if the header's length differs from the
            /// size of this type, and with the underlying error if reading
            /// fails.
            pub async fn read_payload<Sock: AsyncRead + Unpin>(
                header: Header,
                sock: &mut Sock,
            ) -> Result<Self, io::Error> {
                let expected = mem::size_of::<Self>();
                if header.length() as usize != expected {
                    let message = concat!("invalid length for ", stringify!($type), " PDU");
                    return Err(io::Error::new(io::ErrorKind::InvalidData, message));
                }

                let mut value = Self::default();
                let payload = &mut value.as_mut()[Header::LEN..];
                sock.read_exact(payload).await?;
                value.header = header;
                Ok(value)
            }
        }
    };
}
// Header fields common to all PDUs.
/// The fixed eight-octet header at the start of a PDU.
///
/// `session_id` and `length` are stored in big-endian byte order so that the
/// struct's bytes can be sent as they are; the accessors convert back to
/// native byte order.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Header {
    version: u8,
    pdu: u8,
    session_id: u16,
    length: u32,
}

impl Header {
    /// Size of the header in octets.
    const LEN: usize = mem::size_of::<Self>();

    /// Builds a header from native-endian values without validating them.
    pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self {
        Self {
            version,
            pdu,
            session_id: session.to_be(),
            length: length.to_be(),
        }
    }

    /// Reads exactly `HEADER_LEN` octets from `sock` without interpreting them.
    pub async fn read_raw<S: AsyncRead + Unpin>(
        sock: &mut S,
    ) -> Result<[u8; HEADER_LEN], io::Error> {
        let mut raw = [0u8; HEADER_LEN];
        sock.read_exact(&mut raw).await?;
        Ok(raw)
    }

    /// Parses a header from its wire bytes.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the length field is smaller than
    /// `HEADER_LEN` or larger than `MAX_PDU_LEN`.
    pub fn from_raw(buf: [u8; HEADER_LEN]) -> Result<Self, io::Error> {
        let [version, pdu, s0, s1, l0, l1, l2, l3] = buf;
        let length = u32::from_be_bytes([l0, l1, l2, l3]);

        if length < HEADER_LEN as u32 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid PDU length",
            ));
        }
        if length > MAX_PDU_LEN {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "PDU too large"));
        }

        Ok(Self::new(version, pdu, u16::from_be_bytes([s0, s1]), length))
    }

    /// Reads a header from `sock` and validates it via [`Header::from_raw`].
    pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self, io::Error> {
        let raw = Self::read_raw(sock).await?;
        Self::from_raw(raw)
    }

    /// Returns the version field.
    pub fn version(self) -> u8 {
        self.version
    }

    /// Returns the PDU type field.
    pub fn pdu(self) -> u8 {
        self.pdu
    }

    /// Returns the session field in native byte order.
    pub fn session_id(self) -> u16 {
        let wire = self.session_id;
        u16::from_be(wire)
    }

    /// Returns the length field in native byte order.
    pub fn length(self) -> u32 {
        let wire = self.length;
        u32::from_be(wire)
    }

    /// Returns the length field as a `usize`.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the value does not fit into `usize`.
    pub fn pdu_len(self) -> Result<usize, io::Error> {
        match usize::try_from(self.length()) {
            Ok(len) => Ok(len),
            Err(_) => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU too large for this system to handle",
            )),
        }
    }

    /// Returns the session field interpreted as an error code.
    ///
    /// Debug builds assert that the header belongs to an Error Report.
    pub fn error_code(self) -> u16 {
        debug_assert_eq!(self.pdu(), ErrorReport::PDU);
        self.session_id()
    }
}
common!(Header);
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct HeaderWithFlags {
version: u8,
pdu: u8,
flags: u8,
zero: u8,
length: u32,
}
impl HeaderWithFlags {
pub fn new(version: u8, pdu: u8, flags: Flags, length: u32) -> Self {
HeaderWithFlags {
version,
pdu,
flags: flags.into_u8(),
zero: ZERO_8,
length: length.to_be(),
}
}
pub async fn read<S: AsyncRead + Unpin>(sock: &mut S) -> Result<Self> {
let mut buf = [0u8; HEADER_LEN];
// 1. 精确读取 8 字节
sock.read_exact(&mut buf).await?;
// 2. 手动解析(大端)
let version = buf[0];
let pdu = buf[1];
let flags = buf[2];
let zero = buf[3];
let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
// 3. 基础合法性校验
if length < HEADER_LEN as u32 {
bail!("Invalid PDU length");
}
// 限制最大长度
if length > MAX_PDU_LEN {
bail!("PDU too large");
}
Ok(Self {
version,
pdu,
flags,
zero,
length: length.to_be(),
})
}
pub fn version(self) -> u8 {
self.version
}
pub fn pdu(self) -> u8 {
self.pdu
}
pub fn flags(self) -> Flags {
Flags(self.flags)
}
pub fn zero(self) -> u8 {
self.zero
}
pub fn length(self) -> u32 {
u32::from_be(self.length)
}
}
// Serial Notify
/// Serial Notify PDU: a header followed by a big-endian serial number.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct SerialNotify {
    header: Header,
    serial_number: u32,
}

impl SerialNotify {
    /// PDU type value of a Serial Notify.
    pub const PDU: u8 = 0;

    /// Builds a Serial Notify from native-endian values.
    pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self {
        let header = Header::new(version, Self::PDU, session_id, Self::size());
        Self {
            header,
            serial_number: serial_number.to_be(),
        }
    }

    /// Returns the serial number in native byte order.
    pub fn serial_number(self) -> u32 {
        let wire = self.serial_number;
        u32::from_be(wire)
    }
}
concrete!(SerialNotify);
// Serial Query
/// Serial Query PDU: a header followed by a big-endian serial number.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct SerialQuery {
    header: Header,
    serial_number: u32,
}

impl SerialQuery {
    /// PDU type value of a Serial Query.
    pub const PDU: u8 = 1;

    /// Builds a Serial Query from native-endian values.
    pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self {
        let header = Header::new(version, Self::PDU, session_id, Self::size());
        Self {
            header,
            serial_number: serial_number.to_be(),
        }
    }

    /// Returns the serial number in native byte order.
    pub fn serial_number(self) -> u32 {
        let wire = self.serial_number;
        u32::from_be(wire)
    }
}
concrete!(SerialQuery);
// Reset Query
/// Reset Query PDU: consists of the header only.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ResetQuery {
    header: Header,
}

impl ResetQuery {
    /// PDU type value of a Reset Query.
    pub const PDU: u8 = 2;

    /// Builds a Reset Query; the session field is set to zero.
    pub fn new(version: u8) -> Self {
        let length = HEADER_LEN as u32;
        let header = Header::new(version, Self::PDU, ZERO_16, length);
        Self { header }
    }
}
concrete!(ResetQuery);
// Cache Response
/// Cache Response PDU: consists of the header only.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct CacheResponse {
    header: Header,
}

impl CacheResponse {
    /// PDU type value of a Cache Response.
    pub const PDU: u8 = 3;

    /// Builds a Cache Response for the given session.
    pub fn new(version: u8, session_id: u16) -> Self {
        let length = HEADER_LEN as u32;
        let header = Header::new(version, Self::PDU, session_id, length);
        Self { header }
    }
}
concrete!(CacheResponse);
// Flags
/// The flags octet of a PDU; the lowest bit distinguishes announce from
/// withdraw.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Flags(u8);

impl Flags {
    /// Wraps a raw flags octet without checking it.
    pub fn new(raw: u8) -> Self {
        Flags(raw)
    }

    /// Returns `true` if the lowest bit is set.
    pub fn is_announce(self) -> bool {
        (self.0 & 0x01) != 0
    }

    /// Returns `true` if the lowest bit is clear.
    pub fn is_withdraw(self) -> bool {
        (self.0 & 0x01) == 0
    }

    /// Returns the raw flags octet.
    pub fn into_u8(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}
// IPv4 Prefix
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct IPv4Prefix {
header: Header,
flags: Flags,
prefix_len: u8,
max_len: u8,
zero: u8,
prefix: u32,
asn: u32,
}
impl IPv4Prefix {
pub const PDU: u8 = 4;
pub fn new(
version: u8,
flags: Flags,
prefix_len: u8,
max_len: u8,
prefix: Ipv4Addr,
asn: Asn,
) -> Self {
IPv4Prefix {
header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN),
flags,
prefix_len,
max_len,
zero: ZERO_8,
prefix: u32::from(prefix).to_be(),
asn: asn.into_u32().to_be(),
}
}
pub fn flag(self) -> Flags {
self.flags
}
pub fn prefix_len(self) -> u8 {
self.prefix_len
}
pub fn max_len(self) -> u8 {
self.max_len
}
pub fn prefix(self) -> Ipv4Addr {
u32::from_be(self.prefix).into()
}
pub fn asn(self) -> Asn {
u32::from_be(self.asn).into()
}
}
concrete!(IPv4Prefix);
// IPv6 Prefix
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct IPv6Prefix {
header: Header,
flags: Flags,
prefix_len: u8,
max_len: u8,
zero: u8,
prefix: u128,
asn: u32,
}
impl IPv6Prefix {
pub const PDU: u8 = 6;
pub fn new(
version: u8,
flags: Flags,
prefix_len: u8,
max_len: u8,
prefix: Ipv6Addr,
asn: Asn,
) -> Self {
IPv6Prefix {
header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN),
flags,
prefix_len,
max_len,
zero: ZERO_8,
prefix: u128::from(prefix).to_be(),
asn: asn.into_u32().to_be(),
}
}
pub fn flag(self) -> Flags {
self.flags
}
pub fn prefix_len(self) -> u8 {
self.prefix_len
}
pub fn max_len(self) -> u8 {
self.max_len
}
pub fn prefix(self) -> Ipv6Addr {
u128::from_be(self.prefix).into()
}
pub fn asn(self) -> Asn {
u32::from_be(self.asn).into()
}
}
concrete!(IPv6Prefix);
// End of Data
/// End of Data PDU in either of its two wire layouts.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)]
pub enum EndOfData {
    V0(EndOfDataV0),
    V1(EndOfDataV1),
}

impl EndOfData {
    /// Builds the End of Data variant matching `version`.
    ///
    /// Version 0 yields the short layout and ignores `timing`; every other
    /// version yields the layout that carries the timing values.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`EndOfDataV1::new`] when the timing values
    /// are rejected.
    pub fn new(
        version: u8,
        session_id: u16,
        serial_number: u32,
        timing: Timing,
    ) -> Result<Self, io::Error> {
        match version {
            0 => {
                let pdu = EndOfDataV0::new(version, session_id, serial_number);
                Ok(Self::V0(pdu))
            }
            _ => EndOfDataV1::new(version, session_id, serial_number, timing).map(Self::V1),
        }
    }
}
/// End of Data PDU in the version 0 layout: a header followed by a
/// big-endian serial number.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct EndOfDataV0 {
    header: Header,
    serial_number: u32,
}

impl EndOfDataV0 {
    /// PDU type value of an End of Data.
    pub const PDU: u8 = 7;

    /// Builds a version 0 End of Data from native-endian values.
    pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self {
        let header = Header::new(version, Self::PDU, session_id, END_OF_DATA_V0_LEN);
        Self {
            header,
            serial_number: serial_number.to_be(),
        }
    }

    /// Returns the serial number in native byte order.
    pub fn serial_number(self) -> u32 {
        let wire = self.serial_number;
        u32::from_be(wire)
    }
}
concrete!(EndOfDataV0);
/// End of Data PDU in the layout that carries timing values after the
/// serial number.
///
/// All fields after the header are stored in big-endian byte order. The
/// readers are written out by hand, rather than generated by `concrete!`, so
/// that they can validate the timing values.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct EndOfDataV1 {
    header: Header,
    serial_number: u32,
    refresh_interval: u32,
    retry_interval: u32,
    expire_interval: u32,
}

impl EndOfDataV1 {
    /// PDU type value of an End of Data.
    pub const PDU: u8 = 7;

    /// Returns the version field of the header.
    pub fn version(&self) -> u8 {
        Header::version(self.header)
    }

    /// Returns the session field of the header.
    pub fn session_id(&self) -> u16 {
        Header::session_id(self.header)
    }

    /// Returns the PDU type field of the header.
    pub fn pdu(&self) -> u8 {
        Header::pdu(self.header)
    }

    /// Returns the size of the PDU in octets.
    pub fn size() -> u32 {
        let octets = mem::size_of::<Self>();
        octets as u32
    }

    /// Builds an End of Data PDU from native-endian values.
    ///
    /// # Errors
    ///
    /// Fails if `timing.validate()` rejects the timing values.
    pub fn new(
        version: u8,
        session_id: u16,
        serial_number: u32,
        timing: Timing,
    ) -> Result<Self, io::Error> {
        timing.validate()?;
        let header = Header::new(version, Self::PDU, session_id, END_OF_DATA_V1_LEN);
        Ok(Self {
            header,
            serial_number: serial_number.to_be(),
            refresh_interval: timing.refresh.to_be(),
            retry_interval: timing.retry.to_be(),
            expire_interval: timing.expire.to_be(),
        })
    }

    /// Returns the serial number in native byte order.
    pub fn serial_number(self) -> u32 {
        let wire = self.serial_number;
        u32::from_be(wire)
    }

    /// Returns the three intervals in native byte order.
    pub fn timing(self) -> Timing {
        let refresh = u32::from_be(self.refresh_interval);
        let retry = u32::from_be(self.retry_interval);
        let expire = u32::from_be(self.expire_interval);
        Timing {
            refresh,
            retry,
            expire,
        }
    }

    /// Checks the decoded timing values.
    fn validate(&self) -> Result<(), io::Error> {
        let timing = self.timing();
        timing.validate()
    }

    /// Reads a complete PDU from `sock` and validates its timing values.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header carries a different PDU type or
    /// a length other than the size of this type, with the validation error
    /// if the timing values are rejected, and with the underlying error if
    /// reading fails.
    pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
        let mut value = Self::default();
        sock.read_exact(value.header.as_mut()).await?;

        let header = value.header;
        if header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU type mismatch when expecting EndOfDataV1",
            ));
        }
        if header.length() as usize != mem::size_of::<Self>() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid length for EndOfDataV1",
            ));
        }

        let payload = &mut value.as_mut()[Header::LEN..];
        sock.read_exact(payload).await?;
        value.validate()?;
        Ok(value)
    }

    /// Reads the part of the PDU that follows an already received header and
    /// validates the timing values.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header's length differs from the size
    /// of this type, with the validation error if the timing values are
    /// rejected, and with the underlying error if reading fails.
    pub async fn read_payload<Sock: AsyncRead + Unpin>(
        header: Header,
        sock: &mut Sock,
    ) -> Result<Self, io::Error> {
        let expected = mem::size_of::<Self>();
        if header.length() as usize != expected {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid length for EndOfDataV1 PDU",
            ));
        }

        let mut value = Self::default();
        let payload = &mut value.as_mut()[Header::LEN..];
        sock.read_exact(payload).await?;
        value.header = header;
        value.validate()?;
        Ok(value)
    }
}
common!(EndOfDataV1);
// Cache Reset
/// Cache Reset PDU: consists of the header only.
#[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct CacheReset {
    header: Header,
}

impl CacheReset {
    /// PDU type value of a Cache Reset.
    pub const PDU: u8 = 8;

    /// Builds a Cache Reset; the session field is set to zero.
    pub fn new(version: u8) -> Self {
        let length = HEADER_LEN as u32;
        let header = Header::new(version, Self::PDU, ZERO_16, length);
        Self { header }
    }
}
concrete!(CacheReset);
// Error Report
/// An Error Report PDU, kept as its complete wire image.
///
/// The image consists of the header, a length-prefixed copy of the erroneous
/// PDU and a length-prefixed error text.
///
/// # Panics
///
/// The derived `Default` value holds no octets at all; calling the accessors
/// on it panics because there is no header to decode.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ErrorReport {
    octets: Vec<u8>,
}

impl ErrorReport {
    /// The PDU type of an error PDU.
    pub const PDU: u8 = 10;

    /// Header plus the two 32-bit length fields.
    const FIXED_PART_LEN: usize = HEADER_LEN + 2 * mem::size_of::<u32>();

    /// Creates a new error PDU from components.
    ///
    /// The result never exceeds `MAX_PDU_LEN` octets: `pdu` is truncated
    /// first, then `text` is cut to whatever room is left. When the text has
    /// to be cut, the cut is moved back to the start of a UTF-8 sequence so
    /// that a text that was valid UTF-8 stays valid after truncation.
    pub fn new(
        version: u8,
        error_code: u16,
        pdu: impl AsRef<[u8]>,
        text: impl AsRef<[u8]>,
    ) -> Self {
        let pdu = pdu.as_ref();
        let text = text.as_ref();

        let max_payload_len = MAX_PDU_LEN as usize - Self::FIXED_PART_LEN;
        let pdu_len = cmp::min(pdu.len(), max_payload_len);
        let text_room = max_payload_len - pdu_len;
        let text_len = Self::truncated_text_len(text, text_room);
        let size = Self::FIXED_PART_LEN + pdu_len + text_len;

        // All three conversions are infallible: every value is bounded by
        // `MAX_PDU_LEN`, which itself is a `u32`.
        let wire_size = u32::try_from(size).expect("size bounded by MAX_PDU_LEN");
        let wire_pdu_len = u32::try_from(pdu_len).expect("pdu_len bounded by MAX_PDU_LEN");
        let wire_text_len = u32::try_from(text_len).expect("text_len bounded by MAX_PDU_LEN");

        let header = Header::new(version, Self::PDU, error_code, wire_size);
        let mut octets = Vec::with_capacity(size);
        octets.extend_from_slice(header.as_ref());
        octets.extend_from_slice(&wire_pdu_len.to_be_bytes());
        octets.extend_from_slice(&pdu[..pdu_len]);
        octets.extend_from_slice(&wire_text_len.to_be_bytes());
        octets.extend_from_slice(&text[..text_len]);
        ErrorReport { octets }
    }

    /// Returns how many octets of `text` to keep when only `room` octets fit.
    ///
    /// If the text fits, its full length is returned. Otherwise the cut is
    /// placed at the last position not after `room` whose octet is not a
    /// UTF-8 continuation byte, looking back at most three octets (the most a
    /// single UTF-8 sequence can require). If no such position is found the
    /// text is not UTF-8 anyway and is cut at `room`.
    fn truncated_text_len(text: &[u8], room: usize) -> usize {
        if text.len() <= room {
            return text.len();
        }
        // `room < text.len()` here, so every probed index is in bounds.
        (room.saturating_sub(3)..=room)
            .rev()
            .find(|&at| (text[at] & 0xC0) != 0x80)
            .unwrap_or(room)
    }

    /// Reads a complete Error Report from `sock`.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header is invalid, announces a
    /// different PDU type, or the payload fails validation; fails with the
    /// underlying error if reading fails.
    pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
        let header = Header::read(sock).await?;
        if header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU type mismatch when expecting ErrorReport",
            ));
        }
        Self::read_payload(header, sock).await
    }

    /// Reads the part of the PDU that follows an already received header.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header's length is smaller than the
    /// header itself or larger than `MAX_PDU_LEN`, or if the assembled PDU
    /// fails validation; fails with the underlying error if reading fails.
    pub async fn read_payload<Sock: AsyncRead + Unpin>(
        header: Header,
        sock: &mut Sock,
    ) -> Result<Self, io::Error> {
        let total_len = header.pdu_len()?;
        let Some(payload_len) = total_len.checked_sub(mem::size_of::<Header>()) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU size smaller than header size",
            ));
        };
        // `header` may not have gone through `Header::from_raw`, so enforce
        // the size limit before allocating a buffer of a peer-chosen size.
        if total_len > MAX_PDU_LEN as usize {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "PDU too large"));
        }

        let mut octets = Vec::with_capacity(total_len);
        octets.extend_from_slice(header.as_ref());
        octets.resize(total_len, 0);
        sock.read_exact(&mut octets[mem::size_of::<Header>()..])
            .await?;

        let res = ErrorReport { octets };
        res.validate()?;
        debug_assert_eq!(payload_len + mem::size_of::<Header>(), res.octets.len());
        Ok(res)
    }

    /// Returns the version field of the header.
    pub fn version(&self) -> u8 {
        self.header().version()
    }

    /// Returns the error code, or the raw value if it is not a known code.
    pub fn error_code(&self) -> Result<ErrorCode, u16> {
        let raw = self.header().error_code();
        ErrorCode::try_from(raw).map_err(|_| raw)
    }

    /// Returns the embedded copy of the erroneous PDU.
    pub fn erroneous_pdu(&self) -> &[u8] {
        &self.octets[self.erroneous_pdu_range()]
    }

    /// Returns the error text octets.
    pub fn text(&self) -> &[u8] {
        &self.octets[self.text_range()]
    }

    /// Skips over the payload of the error PDU.
    ///
    /// Reads and discards as many octets as the header's length announces
    /// beyond the header itself.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the length is smaller than the header,
    /// with `UnexpectedEof` if the stream ends early, and with the underlying
    /// error if reading fails.
    pub async fn skip_payload<Sock: AsyncRead + Unpin>(
        header: Header,
        sock: &mut Sock,
    ) -> Result<(), io::Error> {
        let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::<Header>()) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU size smaller than header size",
            ));
        };
        let mut buf = [0u8; 1024];
        while remaining > 0 {
            let read_len = cmp::min(remaining, mem::size_of_val(&buf));
            let read = sock.read(&mut buf[..read_len]).await?;
            if read == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "unexpected EOF while skipping ErrorReport payload",
                ));
            }
            remaining -= read;
        }
        Ok(())
    }

    /// Writes the PDU to a writer.
    pub async fn write<A: AsyncWrite + Unpin>(&self, a: &mut A) -> Result<(), io::Error> {
        a.write_all(self.as_ref()).await
    }

    /// Decodes the header from the stored octets.
    ///
    /// Panics if the stored header is not valid.
    fn header(&self) -> Header {
        Header::from_raw(self.header_bytes()).expect("validated ErrorReport header")
    }

    /// Copies the header octets out of the stored image.
    ///
    /// Panics if fewer than `HEADER_LEN` octets are stored.
    fn header_bytes(&self) -> [u8; HEADER_LEN] {
        self.octets[..HEADER_LEN]
            .try_into()
            .expect("ErrorReport shorter than header")
    }

    /// Length of the embedded erroneous PDU as stored after the header.
    fn erroneous_pdu_len(&self) -> usize {
        u32::from_be_bytes(
            self.octets[Header::LEN..Header::LEN + 4]
                .try_into()
                .unwrap(),
        ) as usize
    }

    /// Octet range of the embedded erroneous PDU.
    fn erroneous_pdu_range(&self) -> std::ops::Range<usize> {
        let start = Header::LEN + 4;
        let end = start + self.erroneous_pdu_len();
        start..end
    }

    /// Offset of the text length field, directly after the erroneous PDU.
    fn text_len_offset(&self) -> usize {
        self.erroneous_pdu_range().end
    }

    /// Length of the error text as stored in the text length field.
    fn text_len(&self) -> usize {
        let offset = self.text_len_offset();
        u32::from_be_bytes(self.octets[offset..offset + 4].try_into().unwrap()) as usize
    }

    /// Octet range of the error text.
    fn text_range(&self) -> std::ops::Range<usize> {
        let start = self.text_len_offset() + 4;
        let end = start + self.text_len();
        start..end
    }

    /// Checks that the stored octets form a consistent Error Report.
    ///
    /// Verifies the header, that the header's length matches the number of
    /// stored octets, that both length-prefixed fields fit exactly, and that
    /// the text is valid UTF-8. All offset arithmetic is checked so that
    /// hostile length fields cannot overflow `usize`.
    fn validate(&self) -> Result<(), io::Error> {
        let header = Header::from_raw(self.header_bytes())?;
        if header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "unexpected PDU type for ErrorReport",
            ));
        }
        let total_len = header.pdu_len()?;
        if total_len != self.octets.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport length mismatch",
            ));
        }
        if self.octets.len() < Self::FIXED_PART_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport too short",
            ));
        }

        let pdu_len = self.erroneous_pdu_len();
        let text_len_end = (Header::LEN + 4)
            .checked_add(pdu_len)
            .and_then(|offset| offset.checked_add(4));
        let Some(text_len_end) = text_len_end else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport length overflow",
            ));
        };
        if text_len_end > self.octets.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport truncated before error text length",
            ));
        }
        let text_len_offset = text_len_end - 4;

        let text_len = u32::from_be_bytes(
            self.octets[text_len_offset..text_len_end]
                .try_into()
                .unwrap(),
        ) as usize;
        let Some(text_end) = text_len_end.checked_add(text_len) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport text overflow",
            ));
        };
        if text_end != self.octets.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport payload length mismatch",
            ));
        }

        if std::str::from_utf8(&self.octets[text_len_end..text_end]).is_err() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ErrorReport text is not valid UTF-8",
            ));
        }
        Ok(())
    }
}
// TODO: complete this type.
/// Router Key
///
/// On the wire the PDU is a [`HeaderWithFlags`] followed by a 20-octet
/// subject key identifier, a 32-bit AS number and the subject public key
/// info, which takes up the rest of the PDU.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct RouterKey {
    header: HeaderWithFlags,
    flags: Flags,
    ski: Ski,
    asn: Asn,
    subject_public_key_info: Arc<[u8]>,
}

impl RouterKey {
    /// PDU type value of a Router Key.
    pub const PDU: u8 = 9;

    /// Header, subject key identifier and AS number.
    const BASE_LEN: usize = HEADER_LEN + 20 + 4;

    /// Reads a complete Router Key PDU from `sock`.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header is invalid, announces a
    /// different PDU type, or the PDU fails validation; fails with the
    /// underlying error if reading the payload fails.
    pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
        let header = HeaderWithFlags::read(sock)
            .await
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?;
        if header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU type mismatch when expecting RouterKey",
            ));
        }
        Self::read_payload(header, sock).await
    }

    /// Reads the part of the PDU that follows an already received header.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header's length is smaller than the
    /// fixed part of the PDU or if the PDU fails validation; fails with the
    /// underlying error if reading fails.
    pub async fn read_payload<Sock: AsyncRead + Unpin>(
        header: HeaderWithFlags,
        sock: &mut Sock,
    ) -> Result<Self, io::Error> {
        let total_len = usize::try_from(header.length()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey PDU too large for this system to handle",
            )
        })?;
        if total_len < Self::BASE_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid length for RouterKey PDU",
            ));
        }

        // The length check above guarantees at least 24 octets of body, so
        // the fixed-offset slices below cannot go out of bounds.
        let body_len = total_len - HEADER_LEN;
        let mut body = vec![0u8; body_len];
        sock.read_exact(&mut body).await?;

        let mut ski = [0u8; 20];
        ski.copy_from_slice(&body[..20]);
        let asn = Asn::from(u32::from_be_bytes(body[20..24].try_into().unwrap()));
        // Copy the key straight into the shared allocation instead of going
        // through an intermediate `Vec`.
        let subject_public_key_info = Arc::<[u8]>::from(&body[24..]);

        let res = Self {
            header,
            flags: header.flags(),
            ski: Ski::from_bytes(ski),
            asn,
            subject_public_key_info,
        };
        res.validate()?;
        Ok(res)
    }

    /// Writes the PDU to `w`.
    ///
    /// The header is rebuilt from the stored version and flags; the length
    /// field is derived from the current key size and the reserved octet is
    /// written as zero.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidInput` if the PDU length does not fit the 32-bit
    /// length field, and with the underlying error if writing fails.
    pub async fn write<A: AsyncWrite + Unpin>(&self, w: &mut A) -> Result<(), io::Error> {
        let length = Self::BASE_LEN + self.subject_public_key_info.len();
        let length = u32::try_from(length).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "RouterKey PDU too large to encode",
            )
        })?;
        let header = HeaderWithFlags::new(self.header.version(), Self::PDU, self.flags, length);

        w.write_all(&[
            header.version(),
            header.pdu(),
            header.flags().into_u8(),
            ZERO_8,
        ])
        .await?;
        w.write_all(&length.to_be_bytes()).await?;
        w.write_all(self.ski.as_ref()).await?;
        w.write_all(&self.asn.into_u32().to_be_bytes()).await?;
        w.write_all(&self.subject_public_key_info).await?;
        Ok(())
    }

    /// Builds a Router Key PDU from its components without validating them.
    pub fn new(
        version: u8,
        flags: Flags,
        ski: Ski,
        asn: Asn,
        subject_public_key_info: Arc<[u8]>,
    ) -> Self {
        let length = Self::BASE_LEN + subject_public_key_info.len();
        Self {
            // The cast truncates for keys of 4 GiB or more; this constructor
            // is infallible, and `write` re-derives and checks the length.
            header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32),
            flags,
            ski,
            asn,
            subject_public_key_info,
        }
    }

    /// Returns the subject key identifier.
    pub fn ski(&self) -> Ski {
        self.ski
    }

    /// Returns the AS number.
    pub fn asn(&self) -> Asn {
        self.asn
    }

    /// Returns the subject public key info octets.
    pub fn spki(&self) -> &[u8] {
        &self.subject_public_key_info
    }

    /// Checks a PDU that was received from the wire.
    ///
    /// Rejects a wrong PDU type, a length below the fixed part, a non-zero
    /// reserved octet, flag bits other than the lowest, AS number zero and an
    /// empty key.
    fn validate(&self) -> Result<(), io::Error> {
        if self.header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "unexpected PDU type for RouterKey",
            ));
        }
        if usize::try_from(self.header.length()).unwrap_or(0) < Self::BASE_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey PDU shorter than fixed wire size",
            ));
        }
        if self.header.zero() != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey reserved zero octet must be zero",
            ));
        }
        if self.header.flags().into_u8() & !0x01 != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey flags use reserved bits",
            ));
        }
        if self.asn.into_u32() == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey ASN must not be AS0",
            ));
        }
        if self.subject_public_key_info.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "RouterKey SPKI must not be empty",
            ));
        }
        Ok(())
    }
}
// ASPA
/// ASPA PDU.
///
/// On the wire the PDU is a [`HeaderWithFlags`] followed by the customer AS
/// number and zero or more provider AS numbers, each a big-endian `u32`.
/// The numbers are held here in native byte order.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Aspa {
    header: HeaderWithFlags,
    customer_asn: u32,
    provider_asns: Vec<u32>,
}

impl Aspa {
    /// PDU type value of an ASPA PDU.
    pub const PDU: u8 = 11;

    /// Header and customer AS number.
    const BASE_LEN: usize = HEADER_LEN + 4;

    /// Reads a complete ASPA PDU from `sock`.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header is invalid, announces a
    /// different PDU type, or the PDU fails validation; fails with the
    /// underlying error if reading the payload fails.
    pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
        let header = HeaderWithFlags::read(sock)
            .await
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?;
        if header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "PDU type mismatch when expecting ASPA",
            ));
        }
        Self::read_payload(header, sock).await
    }

    /// Reads the part of the PDU that follows an already received header.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidData` if the header's length is smaller than the
    /// fixed part, if the provider list is not a whole number of 32-bit
    /// values, or if the PDU fails validation; fails with the underlying
    /// error if reading fails.
    pub async fn read_payload<Sock: AsyncRead + Unpin>(
        header: HeaderWithFlags,
        sock: &mut Sock,
    ) -> Result<Self, io::Error> {
        let total_len = usize::try_from(header.length()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA PDU too large for this system to handle",
            )
        })?;
        if total_len < Self::BASE_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid length for ASPA PDU",
            ));
        }
        if (total_len - Self::BASE_LEN) % 4 != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA provider list length must be a multiple of four octets",
            ));
        }

        let body_len = total_len - HEADER_LEN;
        let mut body = vec![0u8; body_len];
        sock.read_exact(&mut body).await?;

        // The length checks above guarantee at least four octets of body and
        // a provider part that divides evenly into four-octet chunks.
        let (customer, providers) = body.split_at(4);
        let customer_asn = u32::from_be_bytes(
            customer
                .try_into()
                .expect("split_at(4) yields four octets"),
        );
        let provider_asns: Vec<u32> = providers
            .chunks_exact(4)
            .map(|chunk| {
                u32::from_be_bytes(
                    chunk
                        .try_into()
                        .expect("chunks_exact(4) yields four octets"),
                )
            })
            .collect();

        let res = Self {
            header,
            customer_asn,
            provider_asns,
        };
        res.validate()?;
        Ok(res)
    }

    /// Writes the PDU to `w`.
    ///
    /// The header is rebuilt from the stored version and flags; the length
    /// field is derived from the current provider list and the reserved
    /// octet is written as zero. The PDU is assembled in memory first and
    /// handed to the writer in a single call rather than one call per
    /// provider.
    ///
    /// # Errors
    ///
    /// Fails with `InvalidInput` if the PDU length does not fit the 32-bit
    /// length field, and with the underlying error if writing fails.
    pub async fn write<A: AsyncWrite + Unpin>(&self, w: &mut A) -> Result<(), io::Error> {
        let length = Self::BASE_LEN + (self.provider_asns.len() * 4);
        let wire_length = u32::try_from(length).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "ASPA PDU too large to encode")
        })?;
        let header = HeaderWithFlags::new(
            self.header.version(),
            Self::PDU,
            self.header.flags(),
            wire_length,
        );

        let mut octets = Vec::with_capacity(length);
        octets.extend_from_slice(&[
            header.version(),
            header.pdu(),
            header.flags().into_u8(),
            ZERO_8,
        ]);
        octets.extend_from_slice(&wire_length.to_be_bytes());
        octets.extend_from_slice(&self.customer_asn.to_be_bytes());
        for asn in &self.provider_asns {
            octets.extend_from_slice(&asn.to_be_bytes());
        }
        w.write_all(&octets).await
    }

    /// Builds an ASPA PDU from its components without validating them.
    pub fn new(version: u8, flags: Flags, customer_asn: u32, provider_asns: Vec<u32>) -> Self {
        let length = Self::BASE_LEN + (provider_asns.len() * 4);
        Self {
            // The cast truncates for absurdly long provider lists; this
            // constructor is infallible, and `write` re-derives and checks
            // the length.
            header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32),
            customer_asn,
            provider_asns,
        }
    }

    /// Checks a PDU that was received from the wire.
    ///
    /// Rejects a wrong PDU type, an inconsistent length, a non-zero reserved
    /// octet, flag bits other than the lowest, AS number zero as customer or
    /// provider, an announcement without providers, a withdrawal with
    /// providers, and a provider list that is not strictly increasing.
    fn validate(&self) -> Result<(), io::Error> {
        if self.header.pdu() != Self::PDU {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "unexpected PDU type for ASPA",
            ));
        }
        let total_len = usize::try_from(self.header.length()).unwrap_or(0);
        if total_len < Self::BASE_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA PDU shorter than fixed wire size",
            ));
        }
        if (total_len - Self::BASE_LEN) % 4 != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA provider list length must be a multiple of four octets",
            ));
        }
        if self.header.zero() != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA reserved zero octet must be zero",
            ));
        }
        if self.header.flags().into_u8() & !0x01 != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA flags use reserved bits",
            ));
        }
        if self.customer_asn == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA customer ASN must not be AS0",
            ));
        }
        let is_announcement = self.header.flags().is_announce();
        if is_announcement && self.provider_asns.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA announcement must contain at least one provider ASN",
            ));
        }
        if !is_announcement && !self.provider_asns.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA withdrawal must not contain provider ASNs",
            ));
        }
        if self.provider_asns.iter().any(|asn| *asn == 0) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA provider list must not contain AS0",
            ));
        }
        if self.provider_asns.windows(2).any(|pair| pair[0] >= pair[1]) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "ASPA provider ASNs must be strictly increasing",
            ));
        }
        Ok(())
    }
}
//--- AsRef and AsMut
/// Exposes the complete wire image of the Error Report.
impl AsRef<[u8]> for ErrorReport {
    fn as_ref(&self) -> &[u8] {
        self.octets.as_slice()
    }
}
/// Exposes the complete wire image of the Error Report for in-place edits.
///
/// Nothing re-validates the octets after they have been modified this way.
impl AsMut<[u8]> for ErrorReport {
    fn as_mut(&mut self) -> &mut [u8] {
        self.octets.as_mut_slice()
    }
}