use crate::data_model::resources::as_resources::Asn; use crate::rtr::error_type::ErrorCode; use crate::rtr::payload::{Ski, Timing}; use anyhow::Result; use std::io; use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::Arc; use std::{cmp, mem}; use tokio::io::AsyncWrite; use anyhow::bail; use serde::Serialize; use std::slice; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; pub const HEADER_LEN: usize = 8; pub const MAX_PDU_LEN: u32 = 65535; pub const IPV4_PREFIX_LEN: u32 = 20; pub const IPV6_PREFIX_LEN: u32 = 32; pub const END_OF_DATA_V0_LEN: u32 = 12; pub const END_OF_DATA_V1_LEN: u32 = 24; pub const ZERO_16: u16 = 0; pub const ZERO_8: u8 = 0; macro_rules! common { ( $type:ident ) => { #[allow(dead_code)] impl $type { /// Writes a value to a writer. pub async fn write(&self, a: &mut A) -> Result<(), io::Error> { a.write_all(self.as_ref()).await } } impl AsRef<[u8]> for $type { fn as_ref(&self) -> &[u8] { unsafe { slice::from_raw_parts(self as *const Self as *const u8, mem::size_of::()) } } } impl AsMut<[u8]> for $type { fn as_mut(&mut self) -> &mut [u8] { unsafe { slice::from_raw_parts_mut(self as *mut Self as *mut u8, mem::size_of::()) } } } }; } macro_rules! concrete { ( $type:ident ) => { common!($type); #[allow(dead_code)] impl $type { /// Returns the value of the version field of the header. pub fn version(&self) -> u8 { self.header.version() } /// Returns the value of the session field of the header. /// /// Note that this field is used for other purposes in some PDU /// types. pub fn session_id(&self) -> u16 { self.header.session_id() } pub fn pdu(&self) -> u8 { self.header.pdu() } /// Returns the PDU size. /// /// The size is returned as a `u32` since that type is used in /// the header. pub fn size() -> u32 { mem::size_of::() as u32 } /// Reads a value from a reader. /// /// If a value with a different PDU type is received, returns an /// error. pub async fn read(sock: &mut Sock) -> Result { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, concat!("PDU type mismatch when expecting ", stringify!($type)), )); } if res.header.length() as usize != res.as_ref().len() { return Err(io::Error::new( io::ErrorKind::InvalidData, concat!("invalid length for ", stringify!($type)), )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; Ok(res) } /// Tries to read a value from a reader. /// /// If a different PDU type is received, returns the header as /// the error case of the ok case. pub async fn try_read( sock: &mut Sock, ) -> Result, io::Error> { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() == ErrorReport::PDU { // Since we should drop the session after an error, we // can safely ignore all the rest of the error for now. return Ok(Err(res.header)); } if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, concat!("PDU type mismatch when expecting ", stringify!($type)), )); } if res.header.length() as usize != res.as_ref().len() { return Err(io::Error::new( io::ErrorKind::InvalidData, concat!("invalid length for ", stringify!($type)), )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; Ok(Ok(res)) } /// Reads only the payload part of a value from a reader. /// /// Assuming that the header was already read and is passed via /// `header`, the function reads the rest of the PUD from the /// reader and returns the complete value. pub async fn read_payload( header: Header, sock: &mut Sock, ) -> Result { if header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, concat!("invalid length for ", stringify!($type), " PDU"), )); } let mut res = Self::default(); sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; res.header = header; Ok(res) } } }; } // 所有PDU公共头部信息 #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Header { version: u8, pdu: u8, session_id: u16, length: u32, } impl Header { const LEN: usize = mem::size_of::(); pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self { Header { version, pdu, session_id: session.to_be(), length: length.to_be(), } } pub async fn read_raw( sock: &mut S, ) -> Result<[u8; HEADER_LEN], io::Error> { let mut buf = [0u8; HEADER_LEN]; sock.read_exact(&mut buf).await?; Ok(buf) } pub fn from_raw(buf: [u8; HEADER_LEN]) -> Result { let version = buf[0]; let pdu = buf[1]; let session_id = u16::from_be_bytes([buf[2], buf[3]]); let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); if length < HEADER_LEN as u32 { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid PDU length", )); } if length > MAX_PDU_LEN { return Err(io::Error::new(io::ErrorKind::InvalidData, "PDU too large")); } Ok(Self { version, pdu, session_id: session_id.to_be(), length: length.to_be(), }) } pub async fn read(sock: &mut S) -> Result { Self::from_raw(Self::read_raw(sock).await?) } pub fn version(self) -> u8 { self.version } pub fn pdu(self) -> u8 { self.pdu } pub fn session_id(self) -> u16 { u16::from_be(self.session_id) } pub fn length(self) -> u32 { u32::from_be(self.length) } pub fn pdu_len(self) -> Result { usize::try_from(self.length()).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "PDU too large for this system to handle", ) }) } pub fn error_code(self) -> u16 { debug_assert_eq!(self.pdu(), ErrorReport::PDU); self.session_id() } } common!(Header); #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct HeaderWithFlags { version: u8, pdu: u8, flags: u8, zero: u8, length: u32, } impl HeaderWithFlags { pub fn new(version: u8, pdu: u8, flags: Flags, length: u32) -> Self { HeaderWithFlags { version, pdu, flags: flags.into_u8(), zero: ZERO_8, length: length.to_be(), } } pub async fn read(sock: &mut S) -> Result { let mut buf = [0u8; HEADER_LEN]; // 1. 精确读取 8 字节 sock.read_exact(&mut buf).await?; // 2. 手动解析(大端) let version = buf[0]; let pdu = buf[1]; let flags = buf[2]; let zero = buf[3]; let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); // 3. 基础合法性校验 if length < HEADER_LEN as u32 { bail!("Invalid PDU length"); } // 限制最大长度 if length > MAX_PDU_LEN { bail!("PDU too large"); } Ok(Self { version, pdu, flags, zero, length: length.to_be(), }) } pub fn version(self) -> u8 { self.version } pub fn pdu(self) -> u8 { self.pdu } pub fn flags(self) -> Flags { Flags(self.flags) } pub fn zero(self) -> u8 { self.zero } pub fn length(self) -> u32 { u32::from_be(self.length) } } // Serial Notify #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct SerialNotify { header: Header, serial_number: u32, } impl SerialNotify { pub const PDU: u8 = 0; pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { SerialNotify { header: Header::new(version, Self::PDU, session_id, Self::size()), serial_number: serial_number.to_be(), } } pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) } } concrete!(SerialNotify); // Serial Query #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct SerialQuery { header: Header, serial_number: u32, } impl SerialQuery { pub const PDU: u8 = 1; pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { SerialQuery { header: Header::new(version, Self::PDU, session_id, Self::size()), serial_number: serial_number.to_be(), } } pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) } } concrete!(SerialQuery); // Reset Query #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ResetQuery { header: Header, } impl ResetQuery { pub const PDU: u8 = 2; pub fn new(version: u8) -> Self { ResetQuery { header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32), } } } concrete!(ResetQuery); // Cache Response #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct CacheResponse { header: Header, } impl CacheResponse { pub const PDU: u8 = 3; pub fn new(version: u8, session_id: u16) -> Self { CacheResponse { header: Header::new(version, Self::PDU, session_id, HEADER_LEN as u32), } } } concrete!(CacheResponse); // Flags #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Flags(u8); impl Flags { pub fn new(raw: u8) -> Self { Self(raw) } pub fn is_announce(self) -> bool { self.0 & 0x01 == 1 } pub fn is_withdraw(self) -> bool { !self.is_announce() } pub fn into_u8(self) -> u8 { self.0 } } // IPv4 Prefix #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct IPv4Prefix { header: Header, flags: Flags, prefix_len: u8, max_len: u8, zero: u8, prefix: u32, asn: u32, } impl IPv4Prefix { pub const PDU: u8 = 4; pub fn new( version: u8, flags: Flags, prefix_len: u8, max_len: u8, prefix: Ipv4Addr, asn: Asn, ) -> Self { IPv4Prefix { header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN), flags, prefix_len, max_len, zero: ZERO_8, prefix: u32::from(prefix).to_be(), asn: asn.into_u32().to_be(), } } pub fn flag(self) -> Flags { self.flags } pub fn prefix_len(self) -> u8 { self.prefix_len } pub fn max_len(self) -> u8 { self.max_len } pub fn prefix(self) -> Ipv4Addr { u32::from_be(self.prefix).into() } pub fn asn(self) -> Asn { u32::from_be(self.asn).into() } } concrete!(IPv4Prefix); // IPv6 Prefix #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct IPv6Prefix { header: Header, flags: Flags, prefix_len: u8, max_len: u8, zero: u8, prefix: u128, asn: u32, } impl IPv6Prefix { pub const PDU: u8 = 6; pub fn new( version: u8, flags: Flags, prefix_len: u8, max_len: u8, prefix: Ipv6Addr, asn: Asn, ) -> Self { IPv6Prefix { header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN), flags, prefix_len, max_len, zero: ZERO_8, prefix: u128::from(prefix).to_be(), asn: asn.into_u32().to_be(), } } pub fn flag(self) -> Flags { self.flags } pub fn prefix_len(self) -> u8 { self.prefix_len } pub fn max_len(self) -> u8 { self.max_len } pub fn prefix(self) -> Ipv6Addr { u128::from_be(self.prefix).into() } pub fn asn(self) -> Asn { u32::from_be(self.asn).into() } } concrete!(IPv6Prefix); // End of Data #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)] pub enum EndOfData { V0(EndOfDataV0), V1(EndOfDataV1), } impl EndOfData { pub fn new( version: u8, session_id: u16, serial_number: u32, timing: Timing, ) -> Result { if version == 0 { Ok(EndOfData::V0(EndOfDataV0::new( version, session_id, serial_number, ))) } else { Ok(EndOfData::V1(EndOfDataV1::new( version, session_id, serial_number, timing, )?)) } } } #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct EndOfDataV0 { header: Header, serial_number: u32, } impl EndOfDataV0 { pub const PDU: u8 = 7; pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { EndOfDataV0 { header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V0_LEN), serial_number: serial_number.to_be(), } } pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) } } concrete!(EndOfDataV0); #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct EndOfDataV1 { header: Header, serial_number: u32, refresh_interval: u32, retry_interval: u32, expire_interval: u32, } impl EndOfDataV1 { pub const PDU: u8 = 7; pub fn version(&self) -> u8 { self.header.version() } pub fn session_id(&self) -> u16 { self.header.session_id() } pub fn pdu(&self) -> u8 { self.header.pdu() } pub fn size() -> u32 { mem::size_of::() as u32 } pub fn new( version: u8, session_id: u16, serial_number: u32, timing: Timing, ) -> Result { timing.validate()?; Ok(EndOfDataV1 { header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V1_LEN), serial_number: serial_number.to_be(), refresh_interval: timing.refresh.to_be(), retry_interval: timing.retry.to_be(), expire_interval: timing.expire.to_be(), }) } pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) } pub fn timing(self) -> Timing { Timing { refresh: u32::from_be(self.refresh_interval), retry: u32::from_be(self.retry_interval), expire: u32::from_be(self.expire_interval), } } fn validate(&self) -> Result<(), io::Error> { self.timing().validate() } pub async fn read(sock: &mut Sock) -> Result { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU type mismatch when expecting EndOfDataV1", )); } if res.header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for EndOfDataV1", )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; res.validate()?; Ok(res) } pub async fn read_payload( header: Header, sock: &mut Sock, ) -> Result { if header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for EndOfDataV1 PDU", )); } let mut res = Self::default(); sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; res.header = header; res.validate()?; Ok(res) } } common!(EndOfDataV1); // Cache Reset #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct CacheReset { header: Header, } impl CacheReset { pub const PDU: u8 = 8; pub fn new(version: u8) -> Self { CacheReset { header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32), } } } concrete!(CacheReset); // Error Report #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ErrorReport { octets: Vec, } impl ErrorReport { /// The PDU type of an error PDU. pub const PDU: u8 = 10; const FIXED_PART_LEN: usize = HEADER_LEN + 2 * mem::size_of::(); /// Creates a new error PDU from components. pub fn new( version: u8, error_code: u16, pdu: impl AsRef<[u8]>, text: impl AsRef<[u8]>, ) -> Self { let pdu = pdu.as_ref(); let text = text.as_ref(); let max_payload_len = MAX_PDU_LEN as usize - Self::FIXED_PART_LEN; let pdu_len = cmp::min(pdu.len(), max_payload_len); let text_room = max_payload_len - pdu_len; let text_len = cmp::min(text.len(), text_room); let size = Self::FIXED_PART_LEN + pdu_len + text_len; let header = Header::new(version, 10, error_code, u32::try_from(size).unwrap()); let mut octets = Vec::with_capacity(size); octets.extend_from_slice(header.as_ref()); octets.extend_from_slice(u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref()); octets.extend_from_slice(&pdu[..pdu_len]); octets.extend_from_slice(u32::try_from(text_len).unwrap().to_be_bytes().as_ref()); octets.extend_from_slice(&text[..text_len]); ErrorReport { octets } } pub async fn read(sock: &mut Sock) -> Result { let header = Header::read(sock).await?; if header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU type mismatch when expecting ErrorReport", )); } Self::read_payload(header, sock).await } pub async fn read_payload( header: Header, sock: &mut Sock, ) -> Result { let total_len = header.pdu_len()?; let Some(payload_len) = total_len.checked_sub(mem::size_of::
()) else { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU size smaller than header size", )); }; let mut octets = Vec::with_capacity(total_len); octets.extend_from_slice(header.as_ref()); octets.resize(total_len, 0); sock.read_exact(&mut octets[mem::size_of::
()..]) .await?; let res = ErrorReport { octets }; res.validate()?; debug_assert_eq!(payload_len + mem::size_of::
(), res.octets.len()); Ok(res) } pub fn version(&self) -> u8 { self.header().version() } pub fn error_code(&self) -> Result { ErrorCode::try_from(self.header().error_code()).map_err(|_| self.header().error_code()) } pub fn erroneous_pdu(&self) -> &[u8] { &self.octets[self.erroneous_pdu_range()] } pub fn text(&self) -> &[u8] { &self.octets[self.text_range()] } /// Skips over the payload of the error PDU. pub async fn skip_payload( header: Header, sock: &mut Sock, ) -> Result<(), io::Error> { let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::
()) else { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU size smaller than header size", )); }; let mut buf = [0u8; 1024]; while remaining > 0 { let read_len = cmp::min(remaining, mem::size_of_val(&buf)); let read = sock.read(&mut buf[..read_len]).await?; if read == 0 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "unexpected EOF while skipping ErrorReport payload", )); } remaining -= read; } Ok(()) } /// Writes the PUD to a writer. pub async fn write(&self, a: &mut A) -> Result<(), io::Error> { a.write_all(self.as_ref()).await } fn header(&self) -> Header { Header::from_raw(self.header_bytes()).expect("validated ErrorReport header") } fn header_bytes(&self) -> [u8; HEADER_LEN] { self.octets[..HEADER_LEN] .try_into() .expect("ErrorReport shorter than header") } fn erroneous_pdu_len(&self) -> usize { u32::from_be_bytes( self.octets[Header::LEN..Header::LEN + 4] .try_into() .unwrap(), ) as usize } fn erroneous_pdu_range(&self) -> std::ops::Range { let start = Header::LEN + 4; let end = start + self.erroneous_pdu_len(); start..end } fn text_len_offset(&self) -> usize { self.erroneous_pdu_range().end } fn text_len(&self) -> usize { let offset = self.text_len_offset(); u32::from_be_bytes(self.octets[offset..offset + 4].try_into().unwrap()) as usize } fn text_range(&self) -> std::ops::Range { let start = self.text_len_offset() + 4; let end = start + self.text_len(); start..end } fn validate(&self) -> Result<(), io::Error> { let header = Header::from_raw(self.header_bytes())?; if header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "unexpected PDU type for ErrorReport", )); } let total_len = header.pdu_len()?; if total_len != self.octets.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport length mismatch", )); } if self.octets.len() < Header::LEN + 8 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport too short", )); } let pdu_len = self.erroneous_pdu_len(); let text_len_offset = Header::LEN + 4 + pdu_len; let Some(text_len_end) = text_len_offset.checked_add(4) else { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport length overflow", )); }; if text_len_end > self.octets.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport truncated before error text length", )); } let text_len = u32::from_be_bytes( self.octets[text_len_offset..text_len_end] .try_into() .unwrap(), ) as usize; let Some(text_end) = text_len_end.checked_add(text_len) else { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport text overflow", )); }; if text_end != self.octets.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport payload length mismatch", )); } if std::str::from_utf8(self.text()).is_err() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ErrorReport text is not valid UTF-8", )); } Ok(()) } } // TODO: 补全 /// Router Key #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct RouterKey { header: HeaderWithFlags, flags: Flags, ski: Ski, asn: Asn, subject_public_key_info: Arc<[u8]>, } impl RouterKey { pub const PDU: u8 = 9; const BASE_LEN: usize = HEADER_LEN + 20 + 4; pub async fn read(sock: &mut Sock) -> Result { let header = HeaderWithFlags::read(sock) .await .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; if header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU type mismatch when expecting RouterKey", )); } Self::read_payload(header, sock).await } pub async fn read_payload( header: HeaderWithFlags, sock: &mut Sock, ) -> Result { let total_len = usize::try_from(header.length()).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "RouterKey PDU too large for this system to handle", ) })?; if total_len < Self::BASE_LEN { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for RouterKey PDU", )); } let body_len = total_len - HEADER_LEN; let mut body = vec![0u8; body_len]; sock.read_exact(&mut body).await?; let mut ski = [0u8; 20]; ski.copy_from_slice(&body[..20]); let asn = Asn::from(u32::from_be_bytes(body[20..24].try_into().unwrap())); let subject_public_key_info = Arc::<[u8]>::from(body[24..].to_vec()); let res = Self { header, flags: header.flags(), ski: Ski::from_bytes(ski), asn, subject_public_key_info, }; res.validate()?; Ok(res) } pub async fn write(&self, w: &mut A) -> Result<(), io::Error> { let length = Self::BASE_LEN + self.subject_public_key_info.len(); let header = HeaderWithFlags::new(self.header.version(), Self::PDU, self.flags, length as u32); w.write_all(&[ header.version(), header.pdu(), header.flags().into_u8(), ZERO_8, ]) .await?; w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(self.ski.as_ref()).await?; w.write_all(&self.asn.into_u32().to_be_bytes()).await?; w.write_all(&self.subject_public_key_info).await?; Ok(()) } pub fn new( version: u8, flags: Flags, ski: Ski, asn: Asn, subject_public_key_info: Arc<[u8]>, ) -> Self { let length = Self::BASE_LEN + subject_public_key_info.len(); Self { header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32), flags, ski, asn, subject_public_key_info, } } pub fn ski(&self) -> Ski { self.ski } pub fn asn(&self) -> Asn { self.asn } pub fn spki(&self) -> &[u8] { &self.subject_public_key_info } fn validate(&self) -> Result<(), io::Error> { if self.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "unexpected PDU type for RouterKey", )); } if usize::try_from(self.header.length()).unwrap_or(0) < Self::BASE_LEN { return Err(io::Error::new( io::ErrorKind::InvalidData, "RouterKey PDU shorter than fixed wire size", )); } if self.header.zero() != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "RouterKey reserved zero octet must be zero", )); } if self.header.flags().into_u8() & !0x01 != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "RouterKey flags use reserved bits", )); } if self.asn.into_u32() == 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "RouterKey ASN must not be AS0", )); } if self.subject_public_key_info.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidData, "RouterKey SPKI must not be empty", )); } Ok(()) } } // ASPA #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Aspa { header: HeaderWithFlags, customer_asn: u32, provider_asns: Vec, } impl Aspa { pub const PDU: u8 = 11; const BASE_LEN: usize = HEADER_LEN + 4; pub async fn read(sock: &mut Sock) -> Result { let header = HeaderWithFlags::read(sock) .await .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; if header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU type mismatch when expecting ASPA", )); } Self::read_payload(header, sock).await } pub async fn read_payload( header: HeaderWithFlags, sock: &mut Sock, ) -> Result { let total_len = usize::try_from(header.length()).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "ASPA PDU too large for this system to handle", ) })?; if total_len < Self::BASE_LEN { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for ASPA PDU", )); } if (total_len - Self::BASE_LEN) % 4 != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA provider list length must be a multiple of four octets", )); } let body_len = total_len - HEADER_LEN; let mut body = vec![0u8; body_len]; sock.read_exact(&mut body).await?; let customer_asn = u32::from_be_bytes(body[..4].try_into().unwrap()); let mut provider_asns = Vec::with_capacity((body.len() - 4) / 4); for chunk in body[4..].chunks_exact(4) { provider_asns.push(u32::from_be_bytes(chunk.try_into().unwrap())); } let res = Self { header, customer_asn, provider_asns, }; res.validate()?; Ok(res) } pub async fn write(&self, w: &mut A) -> Result<(), io::Error> { let length = Self::BASE_LEN + (self.provider_asns.len() * 4); let header = HeaderWithFlags::new( self.header.version(), Self::PDU, self.header.flags(), length as u32, ); w.write_all(&[ header.version(), header.pdu(), header.flags().into_u8(), ZERO_8, ]) .await?; w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(&self.customer_asn.to_be_bytes()).await?; for asn in &self.provider_asns { w.write_all(&asn.to_be_bytes()).await?; } Ok(()) } pub fn new(version: u8, flags: Flags, customer_asn: u32, provider_asns: Vec) -> Self { let length = Self::BASE_LEN + (provider_asns.len() * 4); Self { header: HeaderWithFlags::new(version, Self::PDU, flags, length as u32), customer_asn, provider_asns, } } fn validate(&self) -> Result<(), io::Error> { if self.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "unexpected PDU type for ASPA", )); } let total_len = usize::try_from(self.header.length()).unwrap_or(0); if total_len < Self::BASE_LEN { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA PDU shorter than fixed wire size", )); } if (total_len - Self::BASE_LEN) % 4 != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA provider list length must be a multiple of four octets", )); } if self.header.zero() != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA reserved zero octet must be zero", )); } if self.header.flags().into_u8() & !0x01 != 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA flags use reserved bits", )); } if self.customer_asn == 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA customer ASN must not be AS0", )); } let is_announcement = self.header.flags().is_announce(); if is_announcement && self.provider_asns.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA announcement must contain at least one provider ASN", )); } if !is_announcement && !self.provider_asns.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA withdrawal must not contain provider ASNs", )); } if self.provider_asns.iter().any(|asn| *asn == 0) { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA provider list must not contain AS0", )); } if self.provider_asns.windows(2).any(|pair| pair[0] >= pair[1]) { return Err(io::Error::new( io::ErrorKind::InvalidData, "ASPA provider ASNs must be strictly increasing", )); } Ok(()) } } //--- AsRef and AsMut impl AsRef<[u8]> for ErrorReport { fn as_ref(&self) -> &[u8] { self.octets.as_ref() } } impl AsMut<[u8]> for ErrorReport { fn as_mut(&mut self) -> &mut [u8] { self.octets.as_mut() } }