diff --git a/Cargo.toml b/Cargo.toml index a15d38d..fa7ca1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,14 @@ url = "2.5.8" asn1-rs = "0.7.1" asn1-rs-derive = "0.6.0" asn1 = "0.23.0" +arc-swap = "1.7.0" +chrono = "0.4.44" +bytes = "1.11.1" +tokio = { version = "1.49.0", features = ["full"] } +rand = "0.10.0" +rocksdb = "0.21" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +anyhow = "1" +bincode = "3.0.0" +tracing = "0.1.44" \ No newline at end of file diff --git a/specs/10_slurm.md b/specs/10_slurm.md new file mode 100644 index 0000000..698ce2a --- /dev/null +++ b/specs/10_slurm.md @@ -0,0 +1,236 @@ +# 10. SLURM(Simplified Local Internet Number Resource Management with the RPKI) + +## 10.1 对象定位 + +SLURM是一个JSON文件,允许 RPKI 依赖方在本地“覆盖/修正/忽略”来自上游RPKI数据的内容,而不需要修改或伪造原始RPKI对象。 + +## 10.2 数据格式 (RFC 8416 §3) + +### SLURM + +SLURM是一个只包含一个JSON对象的文件。格式要求如下(RFC 8416 §3.2): + +```text +A SLURM file consists of a single JSON object containing the +following members: + o A "slurmVersion" member that MUST be set to 1, encoded as a number + o A "validationOutputFilters" member (Section 3.3), whose value is + an object. The object MUST contain exactly two members: + * A "prefixFilters" member, whose value is described in + Section 3.3.1. + * A "bgpsecFilters" member, whose value is described in + Section 3.3.2. + o A "locallyAddedAssertions" member (Section 3.4), whose value is an + object. The object MUST contain exactly two members: + * A "prefixAssertions" member, whose value is described in + Section 3.4.1. + * A "bgpsecAssertions" member, whose value is described in + Section 3.4.2. +``` + +一个空的SLURM json结构体如下: + +```json +{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [] + } +} +``` + +### prefixFilters +其中`prefixFilters`格式要求如下(RFC 8416 §3.3.1): + +```text +The above is expressed as a value of the "prefixFilters" member, as +an array of zero or more objects. Each object MUST contain either 1) +one of the following members or 2) one of each of the following +members. + o A "prefix" member, whose value is a string representing either an + IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see + [RFC5952]). + o An "asn" member, whose value is a number. + In addition, each object MAY contain one optional "comment" member, + whose value is a string. +``` +示例: +```json +"prefixFilters": [ + { + "prefix": "192.0.2.0/24", + "comment": "All VRPs encompassed by prefix" + }, + { + "asn": 64496, + "comment": "All VRPs matching ASN" + }, + { + "prefix": "198.51.100.0/24", + "asn": 64497, + "comment": "All VRPs encompassed by prefix, matching ASN" + } +] +``` + +### bgpsecFilters +`bgpsecFilters`格式要求如下(RFC 8416 §3.3.2) + +```text +The above is expressed as a value of the "bgpsecFilters" member, as +an array of zero or more objects. Each object MUST contain one of +either, or one each of both following members: + o An "asn" member, whose value is a number + o An "SKI" member, whose value is the Base64 encoding without + trailing ’=’ (Section 5 of [RFC4648]) of the certificate’s Subject + Key Identifier as described in Section 4.8.2 of [RFC6487]. (This + is the value of the ASN.1 OCTET STRING without the ASN.1 tag or + length fields.) +In addition, each object MAY contain one optional "comment" member, +whose value is a string. +``` + +示例: +```json +"bgpsecFilters": [ + { + "asn": 64496, + "comment": "All keys for ASN" + }, + { + "SKI": "", + "comment": "Key matching Router SKI" + }, + { + "asn": 64497, + "SKI": "", + "comment": "Key for ASN 64497 matching Router SKI" + } +] +``` + +### prefixAssertions +`prefixAssertions`格式要求如下(RFC 8416 §3.4.1) +```text +The above is expressed as a value of the "prefixAssertions" member, +as an array of zero or more objects. Each object MUST contain one of +each of the following members: + o A "prefix" member, whose value is a string representing either an + IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see + [RFC5952]). + o An "asn" member, whose value is a number. +In addition, each object MAY contain one of each of the following +members: + o A "maxPrefixLength" member, whose value is a number. + o A "comment" member, whose value is a string. +``` + +示例: +```json +"prefixAssertions": [ + { + "asn": 64496, + "prefix": "198.51.100.0/24", + "comment": "My other important route" + }, + { + "asn": 64496, + "prefix": "2001:DB8::/32", + "maxPrefixLength": 48, + "comment": "My other important de-aggregated routes" + } +] +``` + +### bgpsecAssertions +`bgpsecAssertions`格式要求如下(RFC 8416 §3.4.2) +```text +The above is expressed as a value of the "bgpsecAssertions" member, +as an array of zero or more objects. Each object MUST contain one +each of all of the following members: + o An "asn" member, whose value is a number. + o An "SKI" member, whose value is the Base64 encoding without + trailing ’=’ (Section 5 of [RFC4648]) of the certificate’s Subject + Key Identifier as described in Section 4.8.2 of [RFC6487] (This is + the value of the ASN.1 OCTET STRING without the ASN.1 tag or + length fields.) + o A "routerPublicKey" member, whose value is the Base64 encoding + without trailing ’=’ (Section 5 of [RFC4648]) of the equivalent to + the subjectPublicKeyInfo value of the router certificate’s public + key, as described in [RFC8208]. This is the full ASN.1 DER + encoding of the subjectPublicKeyInfo, including the ASN.1 tag and + length values of the subjectPublicKeyInfo SEQUENCE. +``` +示例: +```json +"bgpsecAssertions": [ + { + "asn": 64496, + "SKI": "", + "routerPublicKey": "", + "comment": "My known key for my important ASN" + } +] +``` + +## 10.3 抽象数据结构 + +### SLURM +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|---------------------------|------------------------|---------|---------|---------------| +| slurm_version | number | SLURM版本 | 版本必须为1 | RFC 8416 §3.2 | +| validation_output_filters | ValidationOutputFilter | 过滤条件 | | | +| locally_added_assertions | LocallyAddedAssertions | 本地添加断言 | | | + +### ValidationOutputFilter +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|----------------|-------------------|-----------|---------|---------------| +| prefix_filters | Vec | 前缀过滤 | 可以为空数组 | RFC 8416 §3.3 | +| bgpsec_filters | Vec | BGPsec过滤 | 可以为空数组 | RFC 8416 §3.3 | + +### LocallyAddedAssertions +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|-------------------|----------------------|-----------|---------|---------------| +| prefix_assertions | Vec | 前缀断言 | 可以为空数组 | RFC 8416 §3.4 | +| bgpsec_assertions | Vec | BGPsec断言 | 可以为空数组 | RFC 8416 §3.4 | + +### PrefixFilter +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|---------|--------|------|--------------------------------|-----------------| +| prefix | string | 前缀 | IPv4前缀或IPv6前缀,prefix和asn至少存在一个 | RFC 8416 §3.3.1 | +| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 | +| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 | + +### BgpsecFilter +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|---------|--------|------|------------------|------------------| +| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 | +| ski | u8 | | 证书的SKI | RFC 8416 §3.3.1 | +| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 | + +### PrefixAssertion +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|-------------------|--------|--------|---------------|-----------------| +| prefix | string | 前缀 | IPv4前缀或IPv6前缀 | RFC 8416 §3.4.1 | +| asn | number | ASN | | RFC 8416 §3.4.1 | +| max_prefix_length | number | 最大前缀长度 | 可选字段 | RFC 8416 §3.4.1 | +| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.1 | + + +### BgpsecAssertion +| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | +|-------------------|--------|--------|------------------|-----------------| +| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.4.2 | +| ski | u8 | | 证书的SKI | RFC 8416 §3.4.2 | +| router_public_key | u8 | 证书的SKI | | RFC 8416 §3.4.2 | +| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.2 | + +> 注:BGPsec部分可以在第一版考虑先留空 + +## 10.4 规则 + diff --git a/specs/11_rtr.md b/specs/11_rtr.md new file mode 100644 index 0000000..8a6051a --- /dev/null +++ b/specs/11_rtr.md @@ -0,0 +1,65 @@ +# 11. RTR (The Resource Public Key Infrastructure (RPKI) to Router Protocol) + +## 11.1 Cache Server + +### 11.1.1 功能需求 + +- 支持Full Sync(Reset Query) +- 支持Incremental Sync(Serial Query) +- 支持多客户端并发 +- 支持Serial递增 +- 保留一定数量的delta +- 支持原子更新 + +### 11.1.2 架构设计 +采用一级缓存+二级缓存并存的方式。 + +![img.png](img/img.png) + +其中,一级缓存为运行时缓存,主要职责: +- 存储当前完整的snapshot +- 历史Delta队列管理 +- Serial管理 +- RTR查询响应 + +二级缓存为持久化缓存,主要职责: +- snapshot持久化 +- 缓存重启后的快速恢复(snapshot和serial) +- 不参与实时查询 +- 异步写入 + +### 11.1.3 核心数据结构设计 + +#### 11.1.3.1 总cache +```rust +struct RtrCache { + serial: AtomicU32, + snapshot: ArcSwap, + deltas: RwLock>>, + max_delta: usize, +} +``` + +#### 11.1.3.2 Snapshot +```rust +struct Snapshot { + origins: Vec, + router_keys: Vec, + aspas: Vec, + created_at: Instant, +} +``` + +#### 11.1.3.3 Delta +```rust +struct Delta { + serial: u32, + announced: Vec, + withdrawn: Vec, +} +``` + + +## 11.2 Transport + +初版实现RTR over TLS(可外网)和RTR over TCP(内网)两种方式。 \ No newline at end of file diff --git a/specs/img/img.png b/specs/img/img.png new file mode 100644 index 0000000..056efe5 Binary files /dev/null and b/specs/img/img.png differ diff --git a/src/data_model/mod.rs b/src/data_model/mod.rs index 946ad86..2234b3a 100644 --- a/src/data_model/mod.rs +++ b/src/data_model/mod.rs @@ -2,5 +2,5 @@ pub mod crl; mod rc; mod tal; mod ta; -mod resources; +pub mod resources; mod oids; diff --git a/src/data_model/rc.rs b/src/data_model/rc.rs index 4dfdd3e..f5851d7 100644 --- a/src/data_model/rc.rs +++ b/src/data_model/rc.rs @@ -1,639 +1,639 @@ -use asn1::{parse, BitString}; -use der_parser::asn1_rs::Tag; -use der_parser::num_bigint::BigUint; -use url::Url; -use time::OffsetDateTime; -use x509_parser::x509::AlgorithmIdentifier; -use x509_parser::prelude::{Validity, KeyUsage, X509Certificate, FromDer, - X509Version, X509Extension, ParsedExtension, - DistributionPointName, GeneralName}; -use crate::data_model::crl::CrlDecodeError; -use crate::data_model::resources::ip_resources::{Afi, IPAddrBlocks, IPAddress, IPAddressChoice, - IPAddressOrRange, IPAddressPrefix, IPAddressRange, - IPAddressFamily}; -use crate::data_model::resources::as_resources::ASIdentifiers; -use crate::data_model::oids; - - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct SubjectPublicKeyInfo { - pub algorithm_oid: String, - pub subject_public_key: u8, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct AccessDescription { - pub access_method_oid: String, - pub access_location: Url, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct PolicyInformation { - pub policy_oid: String, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct RcExtension { - pub basic_constraints: bool, - pub subject_key_identifier: u8, - pub authority_key_identifier: u8, - pub key_usage: KeyUsage, - pub extended_key_usage_oid: u8, - pub crl_distribution_points: Vec, - pub authority_info_access: Vec, - pub subject_info_access: Vec, - pub certificate_policies: Vec, - pub ip_resource: IPAddrBlocks, - pub as_resource: ASIdentifiers, - -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct ResourceCert { - /// 证书原始DER内容 - pub cert_der: Vec, - - /// 基本证书信息 - pub version: u32, - pub serial_number: BigUint, - pub signature_algorithm_oid: String, - pub issuer_dn: String, - pub subject_dn: String, - pub validity: Validity, - pub subject_public_key_info: SubjectPublicKeyInfo, - pub extensions: RcExtension, -} - - -#[derive(Debug, thiserror::Error)] -pub enum ResourceCertError { - #[error("X.509 parse resource cert error: {0}")] - ParseCert(String), - - #[error("trailing bytes after CRL DER: {0} bytes")] - TrailingBytes(usize), - - #[error("invalid version {0}")] - InvalidVersion(u32), - - #[error("signatureAlgorithm does not match tbsCertificate.signature")] - SignatureAlgorithmMismatch, - - #[error("unsupported signature algorithm")] - UnsupportedSignatureAlgorithm, - - #[error("invalid Cert signature algorithm parameters")] - InvalidSignatureParameters, - - #[error("invalid Cert validity range")] - InvalidValidityRange, - - #[error("Cert not yet valid")] - NotYetValid, - - #[error("expired")] - Expired, - - #[error("Critical error, {0} should be {1}")] - CriticalError(String, String), - - #[error("Duplicate Extension: {0}")] - DuplicateExtension(String), - - #[error("AKI missing keyIdentifier")] - AkiMissingKeyIdentifier, - - #[error("Unexpected parameter: {0}")] - UnexceptedParameter(String), - - #[error("Missing parameter: {0}")] - MissingParameter(String), - - #[error("CRL DP invalid distributionPointName: {0}")] - CrlDpInvalidDistributionPointName(String), - - #[error("CRL DP unexpected distributionPointType: {0}")] - CrlDpUnexpectedDistributionPointType(String), - - #[error("invalid URI: {0}")] - InvalidUri(String), - - #[error("Unsupported General Name in {0}")] - UnsupportedGeneralName(String), - - #[error("Unsupported CRL Distribution Point")] - UnsupportedCrlDistributionPoint, - - #[error("Invalid Access Location Type")] - InvalidAccessLocationType, - - #[error("Empty AuthorityInfoAccess!")] - EmptyAuthorityInfoAccess, - - #[error("Certificate Policies must exists one policy")] - CertificatePoliciesTooMany, - - #[error("Certificate Policies invalid")] - CertificatePoliciesInvalid, - -} - -impl ResourceCert{ - pub fn from_der(cert_der: &[u8]) -> Result { - let (rem, x509_rc) = X509Certificate::from_der(cert_der) - .map_err(|e| ResourceCertError::ParseCert(e.to_string()))?; - - if !rem.is_empty() { - return Err(ResourceCertError::TrailingBytes(rem.len())); - } - - // 校验 - parse_and_validate_cert(x509_rc) - } - - - -} - -fn parse_and_validate_cert(x509_rc: X509Certificate) -> Result { - ///逐个校验RC的内容, 如果有任何一个校验失败, 则返回错误 - - // 1. 版本号必须是V3 - let version = match x509_rc.version() { - X509Version::V3 => X509Version::V3, - v => { - return Err(ResourceCertError::InvalidVersion(v.0)); - } - }; - - // 2.校验签名算法 - // 2.1. 校验外层的签名算法与里层的一致 - let outer = &x509_rc.signature_algorithm; - let inner = &x509_rc.tbs_certificate.signature; - - if outer.algorithm != inner.algorithm || outer.parameters != inner.parameters { - return Err(ResourceCertError::SignatureAlgorithmMismatch); - } - //2.2 RPKI的签名算法必须是rsaWithSHA256 - let signature_algorithm = &x509_rc.signature_algorithm; - if signature_algorithm.algorithm.to_id_string() != oids::OID_SHA256_WITH_RSA_ENCRYPTION { - return Err(ResourceCertError::UnsupportedSignatureAlgorithm); - } - validate_sig_params(signature_algorithm)?; - - // 3. 校验Validity - let validity = x509_rc.validity(); - validate_validity(validity, OffsetDateTime::now_utc())?; - - // 4. SubjectPublicKeyInfo - let subject_public_key_info = x509_rc.tbs_certificate.subject_pki; - - let extensions = parse_and_validate_extensions(x509_rc.extensions())?; - - // TODO - Ok(ResourceCert { - - }) - - -} - -fn validate_sig_params(sig: &AlgorithmIdentifier<'_>) -> Result<(), CrlDecodeError> { - match sig.parameters.as_ref() { - None => Ok(()), - Some(p) if p.tag() == Tag::Null => Ok(()), - Some(_p) => Err(CrlDecodeError::InvalidSignatureAlgorithmParameters), - } -} - -fn validate_validity( - validity: &Validity, - now: OffsetDateTime, -) -> Result<(), ResourceCertError> { - let not_before = validity.not_before.to_datetime(); - let not_after = validity.not_after.to_datetime(); - - if not_after < not_before { - return Err(ResourceCertError::InvalidValidityRange); - } - - if now < not_before { - return Err(ResourceCertError::NotYetValid); - } - - if now > not_after { - return Err(ResourceCertError::Expired); - } - - Ok(()) -} - - -pub fn parse_and_validate_extensions( - exts: &[X509Extension<'_>], -) -> Result { - let mut basic_constraints = None; - let mut ip_addr_blocks = None; - let mut as_identifiers = None; - let mut ski = None; - let mut aki = None; - let mut crl_dp = None; - let mut aia = None; - let mut sia = None; - let mut key_usage = None; - let mut extended_key_usage = None; - let mut certificate_policies = None; - - for ext in exts { - let oid = ext.oid.to_id_string(); - let critical = ext.critical; - match oid.as_str() { - oids::OID_BASIC_CONSTRAINTS => { - if basic_constraints.is_some() { - return Err(ResourceCertError::DuplicateExtension("basicConstraints".into())); - } - if !critical { - return Err(ResourceCertError::CriticalError("basicConstraints".into(), "critical".into())); - } - let bc = parse_basic_constraints(ext)?; - basic_constraints = Some(bc); - } - oids::OID_SUBJECT_KEY_IDENTIFIER => { - if ski.is_some() { - return Err(ResourceCertError::DuplicateExtension("subjectKeyIdentifier".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("subjectKeyIdentifier".into(), "non-critical".into())); - } - let s = parse_subject_key_identifier(ext)?; - ski = Some(s); - } - oids::OID_AUTHORITY_KEY_IDENTIFIER => { - if aki.is_some() { - return Err(ResourceCertError::DuplicateExtension("authorityKeyIdentifier".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("authorityKeyIdentifier".into(), "non-critical".into())); - } - let a = parse_authority_key_identifier(ext)?; - aki = Some(a); - } - oids::OID_KEY_USAGE => { - if key_usage.is_some() { - return Err(ResourceCertError::DuplicateExtension("keyUsage".into())); - } - if !critical { - return Err(ResourceCertError::CriticalError("keyUsage".into(), "critical".into())); - } - let ku = parse_key_usage(ext)?; - key_usage = Some(ku); - } - oids::OID_EXTENDED_KEY_USAGE => { - if extended_key_usage.is_some() { - return Err(ResourceCertError::DuplicateExtension("extendedKeyUsage".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("extendedKeyUsage".into(), "non-critical".into())); - } - let eku = oids::OID_EXTENDED_KEY_USAGE; - } - oids::OID_CRL_DISTRIBUTION_POINTS => { - if crl_dp.is_some() { - return Err(ResourceCertError::DuplicateExtension("crlDistributionPoints".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("crlDistributionPoints".into(), "non-critical".into())); - } - let cdp = parse_crl_distribution_points(ext)?; - crl_dp = Some(cdp); - } - oids::OID_AUTHORITY_INFO_ACCESS => { - if aia.is_some() { - return Err(ResourceCertError::DuplicateExtension("authorityInfoAccess".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("authorityInfoAccess".into(), "non-critical".into())); - } - let p_aia = parse_authority_info_access(ext)?; - aia = Some(p_aia); - } - oids::OID_SUBJECT_INFO_ACCESS => { - if sia.is_some() { - return Err(ResourceCertError::DuplicateExtension("subjectInfoAccess".into())); - } - if critical { - return Err(ResourceCertError::CriticalError("subjectInfoAccess".into(), "non-critical".into())); - } - let p_sia = parse_subject_info_access(ext)?; - sia = Some(p_sia); - } - oids::OID_CERTIFICATE_POLICIES => { - if certificate_policies.is_some() { - return Err(ResourceCertError::DuplicateExtension("certificatePolicies".into())); - } - if !critical { - return Err(ResourceCertError::CriticalError("certificatePolicies".into(), "critical".into())); - } - let p_cp = parse_certificate_policies(ext)?; - certificate_policies = Some(p_cp); - } - oids::OID_IP_ADDRESS_BLOCKS => { - if ip_addr_blocks.is_some() { - return Err(ResourceCertError::DuplicateExtension("ipAddressBlocks".into())); - } - if !critical { - return Err(ResourceCertError::CriticalError("ipAddressBlocks".into(), "critical".into())); - } - let p_ip = parse_ip_address_blocks(ext)?; - ip_addr_blocks = Some(p_ip); - } - oids::OID_AS_IDENTIFIERS => { - if as_identifiers.is_some() { - return Err(ResourceCertError::DuplicateExtension("asIdentifiers".into())); - } - if !critical { - return Err(ResourceCertError::CriticalError("asIdentifiers".into(), "critical".into())); - } - let p_as = parse_as_identifiers(ext)?; - as_identifiers = Some(p_as); - } - } - } - - // TODO: - Ok(RcExtension { - - } -} - -fn parse_basic_constraints(ext: &X509Extension<'_>) -> Result { - let ParsedExtension::BasicConstraints(bc) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert("basicConstraints parse failed".into())); - }; - Ok(bc.ca) -} - -fn parse_subject_key_identifier(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let ParsedExtension::SubjectKeyIdentifier(s) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert("subjectKeyIdentifier parse failed".into())); - }; - Ok(s.0.to_vec()) -} - -fn parse_authority_key_identifier(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let ParsedExtension::AuthorityKeyIdentifier(aki) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert("authorityKeyIdentifier parse failed".into())); - }; - let key_id = aki - .key_identifier - .as_ref() - .ok_or(ResourceCertError::MissingParameter("key_identifier".into()))?; - - if aki.authority_cert_issuer.is_some() { - return Err(ResourceCertError::UnexceptedParameter("authority_cert_issuer".into())); - } - if aki.authority_cert_serial.is_some() { - return Err(ResourceCertError::UnexceptedParameter("authority_cert_serial".into())); - } - - - Ok(key_id.0.to_vec()) -} - -fn parse_key_usage(ext: &X509Extension<'_>) -> Result { - let ParsedExtension::KeyUsage(ku) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert("keyUsage parse failed".into())); - }; - Ok(ku.clone()) -} - -fn parse_crl_distribution_points(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let ParsedExtension::CRLDistributionPoints(cdp) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert("crlDistributionPoints parse failed".into())); - }; - let mut urls = Vec::new(); - for point in cdp.points.iter() { - if point.reasons.is_some() { - return Err(ResourceCertError::UnexceptedParameter("reasons".into())); - } - if point.crl_issuer.is_some() { - return Err(ResourceCertError::UnexceptedParameter("crl_issuer".into())); - } - - let dp_name = point.distribution_point.as_ref() - .ok_or(ResourceCertError::MissingParameter("distribution_point".into()))?; - match dp_name { - DistributionPointName::FullName(names) => { - for name in names { - match name { - GeneralName::URI(uri) => { - let url = Url::parse(uri) - .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; - urls.push(url); - } - _ => { - return Err(ResourceCertError::UnsupportedGeneralName("distribution_point".into())); - } - } - } - - } - DistributionPointName::NameRelativeToCRLIssuer(_) => { - return Err(ResourceCertError::UnsupportedCrlDistributionPoint); - } - } - } - if urls.is_empty() { - return Err(ResourceCertError::MissingParameter("distribution_point".into())); - } - Ok(urls) -} - -fn parse_authority_info_access( - ext: &X509Extension<'_>, -) -> Result, ResourceCertError> { - let ParsedExtension::AuthorityInfoAccess(aia) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert( - "authorityInfoAccess parse failed".into(), - )); - }; - - let mut access_descriptions = Vec::new(); - - for access in &aia.accessdescs { - let access_method_oid = access.access_method.to_id_string(); - - let uri = match &access.access_location { - GeneralName::URI(uri) => uri, - _ => { - return Err(ResourceCertError::InvalidAccessLocationType); - } - }; - - let url = Url::parse(uri) - .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; - - access_descriptions.push(AccessDescription { - access_method_oid, - access_location: url, - }); - } - - if access_descriptions.is_empty() { - return Err(ResourceCertError::EmptyAuthorityInfoAccess); - } - - Ok(access_descriptions) -} - -fn parse_subject_info_access(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let ParsedExtension::SubjectInfoAccess(sia) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert( - "subjectInfoAccess parse failed".into(), - )); - }; - let mut access_descriptions = Vec::new(); - - for access in &sia.accessdescs { - let access_method_oid = access.access_method.to_id_string(); - - // accessLocation: MUST be URI in RPKI - let uri = match &access.access_location { - GeneralName::URI(uri) => uri, - _ => { - return Err(ResourceCertError::InvalidAccessLocationType); - } - }; - - let url = Url::parse(uri) - .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; - - access_descriptions.push(AccessDescription { - access_method_oid, - access_location: url, - }); - } - - if access_descriptions.is_empty() { - return Err(ResourceCertError::EmptyAuthorityInfoAccess); - } - - Ok(access_descriptions) -} - -fn parse_certificate_policies(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let ParsedExtension::CertificatePolicies(cp) = ext.parsed_extension() else { - return Err(ResourceCertError::ParseCert( - "certificatePolicies parse failed".into(), - )); - }; - let mut policies = Vec::new(); - if cp.len() > 1 { - return Err(ResourceCertError::CertificatePoliciesTooMany); - } - let policy_info = cp.first().unwrap(); - let policy_id = policy_info.policy_id.to_id_string(); - if policy_id != oids::OID_RPKI_CP { - return Err(ResourceCertError::CertificatePoliciesInvalid); - } - let policy_info = PolicyInformation{ - policy_oid: policy_id, - }; - policies.push(policy_info); - Ok(policies) -} - -fn bitstring_to_ip(b: &BitString, afi: &Afi) -> Result { - let bytes = b.as_bytes(); - let ip = match afi { - Afi::Ipv4 => { - if bytes.len() != 4 { return Err(ResourceCertError::ParseCert("IPv4 length mismatch".into())); } - u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u128 - }, - Afi::Ipv6 => { - if bytes.len() != 16 { return Err(ResourceCertError::ParseCert("IPv6 length mismatch".into())); } - u128::from_be_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], - bytes[4], bytes[5], bytes[6], bytes[7], - bytes[8], bytes[9], bytes[10], bytes[11], - bytes[12], bytes[13], bytes[14], bytes[15], - ]) - }, - }; - Ok(IPAddress(ip)) -} - -fn parse_ip_address_blocks(ext: &X509Extension<'_>) -> Result { - - let ip_blocks_der = ext.value; - - let ips = parse(ip_blocks_der, |p| { - // 顶层 SEQUENCE OF IPAddressFamily - let mut ip_families = Vec::new(); - p.read_sequence_of(|p2| { - // 每个 IPAddressFamily 是 SEQUENCE { addressFamily OCTET STRING, ipAddressChoice } - p2.read_sequence(|p3| { - let address_family_bytes = p3.read_element::<&[u8]>()?; - let afi = match address_family_bytes { - [0,1] => Afi::Ipv4, - [0,2] => Afi::Ipv6, - _ => return Err(asn1::ParseError::new(asn1::ParseErrorKind::InvalidValue)), - }; - - // 解析 IPAddressChoice - let ip_address_choice = { - let peek_tag = p3.peek_tag()?; - match peek_tag.tag_number() { - 5 => IPAddressChoice::Inherit, // NULL - 16 => { // SEQUENCE OF IPAddressOrRange - let ranges = p3.read_sequence_of(|p4| { - // 解析 IPAddressOrRange CHOICE - let peek = p4.peek_tag()?.tag_number(); - if peek == 16 { // SEQUENCE -> AddressPrefix - let (addr_bytes, prefix_len): (&[u8], u8) = p4.read_sequence(|p5| { - let addr = p5.read_element::()?.as_bytes(); - let prefix_len = addr.len() as u8 * 8; // 简化:用字节长度推前缀 - Ok((addr, prefix_len)) - })?; - Ok(IPAddressOrRange::AddressPrefix(IPAddressPrefix{ - address: bitstring_to_ip(addr_bytes, &afi)?, - prefix_length: prefix_len, - })) - } else { - // AddressRange - let (min_bytes, max_bytes) = p4.read_sequence(|p5| { - let min = p5.read_element::()?.as_bytes(); - let max = p5.read_element::()?.as_bytes(); - Ok((min, max)) - })?; - Ok(IPAddressOrRange::AddressRange(IPAddressRange{ - min: bitstring_to_ip(min_bytes, &afi)?, - max: bitstring_to_ip(max_bytes, &afi)?, - })) - } - })?; - IPAddressChoice::AddressOrRange(ranges) - } - _ => return Err(asn1::ParseError::new(asn1::ParseErrorKind::InvalidValue)), - } - }; - - ip_families.push(IPAddressFamily { - address_family: afi, - ip_address_choice, - }); - - Ok(()) - }) - })?; - Ok(IPAddrBlocks { ips: ip_families }) - }).map_err(|_| ResourceCertError::ParseCert("Failed to parse IPAddrBlocks DER".into()))?; - - Ok(ips) -} - -fn parse_as_identifiers(ext: &X509Extension<'_>) -> Result, ResourceCertError> { - let as_identifiers_der = ext.value; - // TODO: 解析 ASIdentifiers DER - -} \ No newline at end of file +// use asn1::{parse, BitString}; +// use der_parser::asn1_rs::Tag; +// use der_parser::num_bigint::BigUint; +// use url::Url; +// use time::OffsetDateTime; +// use x509_parser::x509::AlgorithmIdentifier; +// use x509_parser::prelude::{Validity, KeyUsage, X509Certificate, FromDer, +// X509Version, X509Extension, ParsedExtension, +// DistributionPointName, GeneralName}; +// use crate::data_model::crl::CrlDecodeError; +// use crate::data_model::resources::ip_resources::{Afi, IPAddrBlocks, IPAddress, IPAddressChoice, +// IPAddressOrRange, IPAddressPrefix, IPAddressRange, +// IPAddressFamily}; +// use crate::data_model::resources::as_resources::ASIdentifiers; +// use crate::data_model::oids; +// +// +// #[derive(Clone, Debug, PartialEq, Eq)] +// pub struct SubjectPublicKeyInfo { +// pub algorithm_oid: String, +// pub subject_public_key: u8, +// } +// +// #[derive(Clone, Debug, PartialEq, Eq)] +// pub struct AccessDescription { +// pub access_method_oid: String, +// pub access_location: Url, +// } +// +// #[derive(Clone, Debug, PartialEq, Eq)] +// pub struct PolicyInformation { +// pub policy_oid: String, +// } +// +// #[derive(Clone, Debug, PartialEq, Eq)] +// pub struct RcExtension { +// pub basic_constraints: bool, +// pub subject_key_identifier: u8, +// pub authority_key_identifier: u8, +// pub key_usage: KeyUsage, +// pub extended_key_usage_oid: u8, +// pub crl_distribution_points: Vec, +// pub authority_info_access: Vec, +// pub subject_info_access: Vec, +// pub certificate_policies: Vec, +// pub ip_resource: IPAddrBlocks, +// pub as_resource: ASIdentifiers, +// +// } +// +// #[derive(Clone, Debug, PartialEq, Eq)] +// pub struct ResourceCert { +// /// 证书原始DER内容 +// pub cert_der: Vec, +// +// /// 基本证书信息 +// pub version: u32, +// pub serial_number: BigUint, +// pub signature_algorithm_oid: String, +// pub issuer_dn: String, +// pub subject_dn: String, +// pub validity: Validity, +// pub subject_public_key_info: SubjectPublicKeyInfo, +// pub extensions: RcExtension, +// } +// +// +// #[derive(Debug, thiserror::Error)] +// pub enum ResourceCertError { +// #[error("X.509 parse resource cert error: {0}")] +// ParseCert(String), +// +// #[error("trailing bytes after CRL DER: {0} bytes")] +// TrailingBytes(usize), +// +// #[error("invalid version {0}")] +// InvalidVersion(u32), +// +// #[error("signatureAlgorithm does not match tbsCertificate.signature")] +// SignatureAlgorithmMismatch, +// +// #[error("unsupported signature algorithm")] +// UnsupportedSignatureAlgorithm, +// +// #[error("invalid Cert signature algorithm parameters")] +// InvalidSignatureParameters, +// +// #[error("invalid Cert validity range")] +// InvalidValidityRange, +// +// #[error("Cert not yet valid")] +// NotYetValid, +// +// #[error("expired")] +// Expired, +// +// #[error("Critical error, {0} should be {1}")] +// CriticalError(String, String), +// +// #[error("Duplicate Extension: {0}")] +// DuplicateExtension(String), +// +// #[error("AKI missing keyIdentifier")] +// AkiMissingKeyIdentifier, +// +// #[error("Unexpected parameter: {0}")] +// UnexceptedParameter(String), +// +// #[error("Missing parameter: {0}")] +// MissingParameter(String), +// +// #[error("CRL DP invalid distributionPointName: {0}")] +// CrlDpInvalidDistributionPointName(String), +// +// #[error("CRL DP unexpected distributionPointType: {0}")] +// CrlDpUnexpectedDistributionPointType(String), +// +// #[error("invalid URI: {0}")] +// InvalidUri(String), +// +// #[error("Unsupported General Name in {0}")] +// UnsupportedGeneralName(String), +// +// #[error("Unsupported CRL Distribution Point")] +// UnsupportedCrlDistributionPoint, +// +// #[error("Invalid Access Location Type")] +// InvalidAccessLocationType, +// +// #[error("Empty AuthorityInfoAccess!")] +// EmptyAuthorityInfoAccess, +// +// #[error("Certificate Policies must exists one policy")] +// CertificatePoliciesTooMany, +// +// #[error("Certificate Policies invalid")] +// CertificatePoliciesInvalid, +// +// } +// +// impl ResourceCert{ +// pub fn from_der(cert_der: &[u8]) -> Result { +// let (rem, x509_rc) = X509Certificate::from_der(cert_der) +// .map_err(|e| ResourceCertError::ParseCert(e.to_string()))?; +// +// if !rem.is_empty() { +// return Err(ResourceCertError::TrailingBytes(rem.len())); +// } +// +// // 校验 +// parse_and_validate_cert(x509_rc) +// } +// +// +// +// } +// +// fn parse_and_validate_cert(x509_rc: X509Certificate) -> Result { +// ///逐个校验RC的内容, 如果有任何一个校验失败, 则返回错误 +// +// // 1. 版本号必须是V3 +// let version = match x509_rc.version() { +// X509Version::V3 => X509Version::V3, +// v => { +// return Err(ResourceCertError::InvalidVersion(v.0)); +// } +// }; +// +// // 2.校验签名算法 +// // 2.1. 校验外层的签名算法与里层的一致 +// let outer = &x509_rc.signature_algorithm; +// let inner = &x509_rc.tbs_certificate.signature; +// +// if outer.algorithm != inner.algorithm || outer.parameters != inner.parameters { +// return Err(ResourceCertError::SignatureAlgorithmMismatch); +// } +// //2.2 RPKI的签名算法必须是rsaWithSHA256 +// let signature_algorithm = &x509_rc.signature_algorithm; +// if signature_algorithm.algorithm.to_id_string() != oids::OID_SHA256_WITH_RSA_ENCRYPTION { +// return Err(ResourceCertError::UnsupportedSignatureAlgorithm); +// } +// validate_sig_params(signature_algorithm)?; +// +// // 3. 校验Validity +// let validity = x509_rc.validity(); +// validate_validity(validity, OffsetDateTime::now_utc())?; +// +// // 4. SubjectPublicKeyInfo +// let subject_public_key_info = x509_rc.tbs_certificate.subject_pki; +// +// let extensions = parse_and_validate_extensions(x509_rc.extensions())?; +// +// // TODO +// Ok(ResourceCert { +// +// }) +// +// +// } +// +// fn validate_sig_params(sig: &AlgorithmIdentifier<'_>) -> Result<(), CrlDecodeError> { +// match sig.parameters.as_ref() { +// None => Ok(()), +// Some(p) if p.tag() == Tag::Null => Ok(()), +// Some(_p) => Err(CrlDecodeError::InvalidSignatureAlgorithmParameters), +// } +// } +// +// fn validate_validity( +// validity: &Validity, +// now: OffsetDateTime, +// ) -> Result<(), ResourceCertError> { +// let not_before = validity.not_before.to_datetime(); +// let not_after = validity.not_after.to_datetime(); +// +// if not_after < not_before { +// return Err(ResourceCertError::InvalidValidityRange); +// } +// +// if now < not_before { +// return Err(ResourceCertError::NotYetValid); +// } +// +// if now > not_after { +// return Err(ResourceCertError::Expired); +// } +// +// Ok(()) +// } +// +// +// pub fn parse_and_validate_extensions( +// exts: &[X509Extension<'_>], +// ) -> Result { +// let mut basic_constraints = None; +// let mut ip_addr_blocks = None; +// let mut as_identifiers = None; +// let mut ski = None; +// let mut aki = None; +// let mut crl_dp = None; +// let mut aia = None; +// let mut sia = None; +// let mut key_usage = None; +// let mut extended_key_usage = None; +// let mut certificate_policies = None; +// +// for ext in exts { +// let oid = ext.oid.to_id_string(); +// let critical = ext.critical; +// match oid.as_str() { +// oids::OID_BASIC_CONSTRAINTS => { +// if basic_constraints.is_some() { +// return Err(ResourceCertError::DuplicateExtension("basicConstraints".into())); +// } +// if !critical { +// return Err(ResourceCertError::CriticalError("basicConstraints".into(), "critical".into())); +// } +// let bc = parse_basic_constraints(ext)?; +// basic_constraints = Some(bc); +// } +// oids::OID_SUBJECT_KEY_IDENTIFIER => { +// if ski.is_some() { +// return Err(ResourceCertError::DuplicateExtension("subjectKeyIdentifier".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("subjectKeyIdentifier".into(), "non-critical".into())); +// } +// let s = parse_subject_key_identifier(ext)?; +// ski = Some(s); +// } +// oids::OID_AUTHORITY_KEY_IDENTIFIER => { +// if aki.is_some() { +// return Err(ResourceCertError::DuplicateExtension("authorityKeyIdentifier".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("authorityKeyIdentifier".into(), "non-critical".into())); +// } +// let a = parse_authority_key_identifier(ext)?; +// aki = Some(a); +// } +// oids::OID_KEY_USAGE => { +// if key_usage.is_some() { +// return Err(ResourceCertError::DuplicateExtension("keyUsage".into())); +// } +// if !critical { +// return Err(ResourceCertError::CriticalError("keyUsage".into(), "critical".into())); +// } +// let ku = parse_key_usage(ext)?; +// key_usage = Some(ku); +// } +// oids::OID_EXTENDED_KEY_USAGE => { +// if extended_key_usage.is_some() { +// return Err(ResourceCertError::DuplicateExtension("extendedKeyUsage".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("extendedKeyUsage".into(), "non-critical".into())); +// } +// let eku = oids::OID_EXTENDED_KEY_USAGE; +// } +// oids::OID_CRL_DISTRIBUTION_POINTS => { +// if crl_dp.is_some() { +// return Err(ResourceCertError::DuplicateExtension("crlDistributionPoints".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("crlDistributionPoints".into(), "non-critical".into())); +// } +// let cdp = parse_crl_distribution_points(ext)?; +// crl_dp = Some(cdp); +// } +// oids::OID_AUTHORITY_INFO_ACCESS => { +// if aia.is_some() { +// return Err(ResourceCertError::DuplicateExtension("authorityInfoAccess".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("authorityInfoAccess".into(), "non-critical".into())); +// } +// let p_aia = parse_authority_info_access(ext)?; +// aia = Some(p_aia); +// } +// oids::OID_SUBJECT_INFO_ACCESS => { +// if sia.is_some() { +// return Err(ResourceCertError::DuplicateExtension("subjectInfoAccess".into())); +// } +// if critical { +// return Err(ResourceCertError::CriticalError("subjectInfoAccess".into(), "non-critical".into())); +// } +// let p_sia = parse_subject_info_access(ext)?; +// sia = Some(p_sia); +// } +// oids::OID_CERTIFICATE_POLICIES => { +// if certificate_policies.is_some() { +// return Err(ResourceCertError::DuplicateExtension("certificatePolicies".into())); +// } +// if !critical { +// return Err(ResourceCertError::CriticalError("certificatePolicies".into(), "critical".into())); +// } +// let p_cp = parse_certificate_policies(ext)?; +// certificate_policies = Some(p_cp); +// } +// oids::OID_IP_ADDRESS_BLOCKS => { +// if ip_addr_blocks.is_some() { +// return Err(ResourceCertError::DuplicateExtension("ipAddressBlocks".into())); +// } +// if !critical { +// return Err(ResourceCertError::CriticalError("ipAddressBlocks".into(), "critical".into())); +// } +// let p_ip = parse_ip_address_blocks(ext)?; +// ip_addr_blocks = Some(p_ip); +// } +// oids::OID_AS_IDENTIFIERS => { +// if as_identifiers.is_some() { +// return Err(ResourceCertError::DuplicateExtension("asIdentifiers".into())); +// } +// if !critical { +// return Err(ResourceCertError::CriticalError("asIdentifiers".into(), "critical".into())); +// } +// let p_as = parse_as_identifiers(ext)?; +// as_identifiers = Some(p_as); +// } +// } +// } +// +// // TODO: +// Ok(RcExtension { +// +// } +// } +// +// fn parse_basic_constraints(ext: &X509Extension<'_>) -> Result { +// let ParsedExtension::BasicConstraints(bc) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert("basicConstraints parse failed".into())); +// }; +// Ok(bc.ca) +// } +// +// fn parse_subject_key_identifier(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let ParsedExtension::SubjectKeyIdentifier(s) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert("subjectKeyIdentifier parse failed".into())); +// }; +// Ok(s.0.to_vec()) +// } +// +// fn parse_authority_key_identifier(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let ParsedExtension::AuthorityKeyIdentifier(aki) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert("authorityKeyIdentifier parse failed".into())); +// }; +// let key_id = aki +// .key_identifier +// .as_ref() +// .ok_or(ResourceCertError::MissingParameter("key_identifier".into()))?; +// +// if aki.authority_cert_issuer.is_some() { +// return Err(ResourceCertError::UnexceptedParameter("authority_cert_issuer".into())); +// } +// if aki.authority_cert_serial.is_some() { +// return Err(ResourceCertError::UnexceptedParameter("authority_cert_serial".into())); +// } +// +// +// Ok(key_id.0.to_vec()) +// } +// +// fn parse_key_usage(ext: &X509Extension<'_>) -> Result { +// let ParsedExtension::KeyUsage(ku) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert("keyUsage parse failed".into())); +// }; +// Ok(ku.clone()) +// } +// +// fn parse_crl_distribution_points(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let ParsedExtension::CRLDistributionPoints(cdp) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert("crlDistributionPoints parse failed".into())); +// }; +// let mut urls = Vec::new(); +// for point in cdp.points.iter() { +// if point.reasons.is_some() { +// return Err(ResourceCertError::UnexceptedParameter("reasons".into())); +// } +// if point.crl_issuer.is_some() { +// return Err(ResourceCertError::UnexceptedParameter("crl_issuer".into())); +// } +// +// let dp_name = point.distribution_point.as_ref() +// .ok_or(ResourceCertError::MissingParameter("distribution_point".into()))?; +// match dp_name { +// DistributionPointName::FullName(names) => { +// for name in names { +// match name { +// GeneralName::URI(uri) => { +// let url = Url::parse(uri) +// .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; +// urls.push(url); +// } +// _ => { +// return Err(ResourceCertError::UnsupportedGeneralName("distribution_point".into())); +// } +// } +// } +// +// } +// DistributionPointName::NameRelativeToCRLIssuer(_) => { +// return Err(ResourceCertError::UnsupportedCrlDistributionPoint); +// } +// } +// } +// if urls.is_empty() { +// return Err(ResourceCertError::MissingParameter("distribution_point".into())); +// } +// Ok(urls) +// } +// +// fn parse_authority_info_access( +// ext: &X509Extension<'_>, +// ) -> Result, ResourceCertError> { +// let ParsedExtension::AuthorityInfoAccess(aia) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert( +// "authorityInfoAccess parse failed".into(), +// )); +// }; +// +// let mut access_descriptions = Vec::new(); +// +// for access in &aia.accessdescs { +// let access_method_oid = access.access_method.to_id_string(); +// +// let uri = match &access.access_location { +// GeneralName::URI(uri) => uri, +// _ => { +// return Err(ResourceCertError::InvalidAccessLocationType); +// } +// }; +// +// let url = Url::parse(uri) +// .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; +// +// access_descriptions.push(AccessDescription { +// access_method_oid, +// access_location: url, +// }); +// } +// +// if access_descriptions.is_empty() { +// return Err(ResourceCertError::EmptyAuthorityInfoAccess); +// } +// +// Ok(access_descriptions) +// } +// +// fn parse_subject_info_access(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let ParsedExtension::SubjectInfoAccess(sia) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert( +// "subjectInfoAccess parse failed".into(), +// )); +// }; +// let mut access_descriptions = Vec::new(); +// +// for access in &sia.accessdescs { +// let access_method_oid = access.access_method.to_id_string(); +// +// // accessLocation: MUST be URI in RPKI +// let uri = match &access.access_location { +// GeneralName::URI(uri) => uri, +// _ => { +// return Err(ResourceCertError::InvalidAccessLocationType); +// } +// }; +// +// let url = Url::parse(uri) +// .map_err(|_| ResourceCertError::InvalidUri(uri.to_string()))?; +// +// access_descriptions.push(AccessDescription { +// access_method_oid, +// access_location: url, +// }); +// } +// +// if access_descriptions.is_empty() { +// return Err(ResourceCertError::EmptyAuthorityInfoAccess); +// } +// +// Ok(access_descriptions) +// } +// +// fn parse_certificate_policies(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let ParsedExtension::CertificatePolicies(cp) = ext.parsed_extension() else { +// return Err(ResourceCertError::ParseCert( +// "certificatePolicies parse failed".into(), +// )); +// }; +// let mut policies = Vec::new(); +// if cp.len() > 1 { +// return Err(ResourceCertError::CertificatePoliciesTooMany); +// } +// let policy_info = cp.first().unwrap(); +// let policy_id = policy_info.policy_id.to_id_string(); +// if policy_id != oids::OID_RPKI_CP { +// return Err(ResourceCertError::CertificatePoliciesInvalid); +// } +// let policy_info = PolicyInformation{ +// policy_oid: policy_id, +// }; +// policies.push(policy_info); +// Ok(policies) +// } +// +// fn bitstring_to_ip(b: &BitString, afi: &Afi) -> Result { +// let bytes = b.as_bytes(); +// let ip = match afi { +// Afi::Ipv4 => { +// if bytes.len() != 4 { return Err(ResourceCertError::ParseCert("IPv4 length mismatch".into())); } +// u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as u128 +// }, +// Afi::Ipv6 => { +// if bytes.len() != 16 { return Err(ResourceCertError::ParseCert("IPv6 length mismatch".into())); } +// u128::from_be_bytes([ +// bytes[0], bytes[1], bytes[2], bytes[3], +// bytes[4], bytes[5], bytes[6], bytes[7], +// bytes[8], bytes[9], bytes[10], bytes[11], +// bytes[12], bytes[13], bytes[14], bytes[15], +// ]) +// }, +// }; +// Ok(IPAddress(ip)) +// } +// +// fn parse_ip_address_blocks(ext: &X509Extension<'_>) -> Result { +// +// let ip_blocks_der = ext.value; +// +// let ips = parse(ip_blocks_der, |p| { +// // 顶层 SEQUENCE OF IPAddressFamily +// let mut ip_families = Vec::new(); +// p.read_sequence_of(|p2| { +// // 每个 IPAddressFamily 是 SEQUENCE { addressFamily OCTET STRING, ipAddressChoice } +// p2.read_sequence(|p3| { +// let address_family_bytes = p3.read_element::<&[u8]>()?; +// let afi = match address_family_bytes { +// [0,1] => Afi::Ipv4, +// [0,2] => Afi::Ipv6, +// _ => return Err(asn1::ParseError::new(asn1::ParseErrorKind::InvalidValue)), +// }; +// +// // 解析 IPAddressChoice +// let ip_address_choice = { +// let peek_tag = p3.peek_tag()?; +// match peek_tag.tag_number() { +// 5 => IPAddressChoice::Inherit, // NULL +// 16 => { // SEQUENCE OF IPAddressOrRange +// let ranges = p3.read_sequence_of(|p4| { +// // 解析 IPAddressOrRange CHOICE +// let peek = p4.peek_tag()?.tag_number(); +// if peek == 16 { // SEQUENCE -> AddressPrefix +// let (addr_bytes, prefix_len): (&[u8], u8) = p4.read_sequence(|p5| { +// let addr = p5.read_element::()?.as_bytes(); +// let prefix_len = addr.len() as u8 * 8; // 简化:用字节长度推前缀 +// Ok((addr, prefix_len)) +// })?; +// Ok(IPAddressOrRange::AddressPrefix(IPAddressPrefix{ +// address: bitstring_to_ip(addr_bytes, &afi)?, +// prefix_length: prefix_len, +// })) +// } else { +// // AddressRange +// let (min_bytes, max_bytes) = p4.read_sequence(|p5| { +// let min = p5.read_element::()?.as_bytes(); +// let max = p5.read_element::()?.as_bytes(); +// Ok((min, max)) +// })?; +// Ok(IPAddressOrRange::AddressRange(IPAddressRange{ +// min: bitstring_to_ip(min_bytes, &afi)?, +// max: bitstring_to_ip(max_bytes, &afi)?, +// })) +// } +// })?; +// IPAddressChoice::AddressOrRange(ranges) +// } +// _ => return Err(asn1::ParseError::new(asn1::ParseErrorKind::InvalidValue)), +// } +// }; +// +// ip_families.push(IPAddressFamily { +// address_family: afi, +// ip_address_choice, +// }); +// +// Ok(()) +// }) +// })?; +// Ok(IPAddrBlocks { ips: ip_families }) +// }).map_err(|_| ResourceCertError::ParseCert("Failed to parse IPAddrBlocks DER".into()))?; +// +// Ok(ips) +// } +// +// fn parse_as_identifiers(ext: &X509Extension<'_>) -> Result, ResourceCertError> { +// let as_identifiers_der = ext.value; +// // TODO: 解析 ASIdentifiers DER +// +// } \ No newline at end of file diff --git a/src/data_model/resources/as_resources.rs b/src/data_model/resources/as_resources.rs index 13eb095..fc49f4d 100644 --- a/src/data_model/resources/as_resources.rs +++ b/src/data_model/resources/as_resources.rs @@ -54,7 +54,7 @@ impl ASRange { } } -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Default)] pub struct Asn(u32); impl Asn { diff --git a/src/data_model/resources/ip_resources.rs b/src/data_model/resources/ip_resources.rs index 1e754bc..efc0250 100644 --- a/src/data_model/resources/ip_resources.rs +++ b/src/data_model/resources/ip_resources.rs @@ -42,6 +42,30 @@ pub struct IPAddressRange { pub max: IPAddress, } +use std::net::{Ipv4Addr, Ipv6Addr}; + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct IPAddress(u128); +impl IPAddress { + pub fn to_ipv4(self) -> Option { + if self.0 <= u32::MAX as u128 { + Some(Ipv4Addr::from(self.0 as u32)) + } else { + None + } + } + + pub fn to_ipv6(self) -> Ipv6Addr { + Ipv6Addr::from(self.0) + } + + pub fn is_ipv4(self) -> bool { + self.0 <= u32::MAX as u128 + } + + pub fn as_u128(self) -> u128 { + self.0 + } +} + diff --git a/src/lib.rs b/src/lib.rs index c0dd03b..72d8e7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,3 @@ pub mod data_model; +mod slurm; +mod rtr; diff --git a/src/rtr/cache.rs b/src/rtr/cache.rs new file mode 100644 index 0000000..29619b7 --- /dev/null +++ b/src/rtr/cache.rs @@ -0,0 +1,538 @@ +use std::collections::{BTreeSet, VecDeque}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use chrono::{DateTime, NaiveDateTime, Utc}; +use serde::{Deserialize, Serialize}; + +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; +use crate::rtr::store_db::RtrStore; + +const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(600); +const DEFAULT_EXPIRE_INTERVAL: Duration = Duration::from_secs(7200); + +#[derive(Debug, Clone)] +pub struct DualTime { + instant: Instant, + utc: DateTime, +} + +impl DualTime { + /// Create current time. + pub fn now() -> Self { + Self { + instant: Instant::now(), + utc: Utc::now(), + } + } + + /// Get UTC time for logs. + pub fn utc(&self) -> DateTime { + self.utc + } + + /// Elapsed duration since creation/reset. + pub fn elapsed(&self) -> Duration { + self.instant.elapsed() + } + + /// Whether duration is expired. + pub fn is_expired(&self, duration: Duration) -> bool { + self.elapsed() >= duration + } + + /// Reset to now. + pub fn reset(&mut self) { + self.instant = Instant::now(); + self.utc = Utc::now(); + } +} + +impl Serialize for DualTime { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.utc.timestamp_millis().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for DualTime { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let millis = i64::deserialize(deserializer)?; + let naive = NaiveDateTime::from_timestamp_millis(millis) + .ok_or_else(|| serde::de::Error::custom("invalid timestamp"))?; + let utc = DateTime::::from_utc(naive, Utc); + Ok(Self { + instant: Instant::now(), + utc, + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Snapshot { + origins: BTreeSet, + router_keys: BTreeSet, + aspas: BTreeSet, + created_at: DualTime, +} + +impl Snapshot { + pub fn new( + origins: BTreeSet, + router_keys: BTreeSet, + aspas: BTreeSet, + ) -> Self { + Snapshot { + origins, + router_keys, + aspas, + created_at: DualTime::now(), + } + } + + pub fn empty() -> Self { + Self::new(BTreeSet::new(), BTreeSet::new(), BTreeSet::new()) + } + + pub fn from_payloads(payloads: Vec) -> Self { + let mut origins = BTreeSet::new(); + let mut router_keys = BTreeSet::new(); + let mut aspas = BTreeSet::new(); + + for p in payloads { + match p { + Payload::RouteOrigin(o) => { + origins.insert(o); + } + Payload::RouterKey(k) => { + router_keys.insert(k); + } + Payload::Aspa(a) => { + aspas.insert(a); + } + } + } + + Snapshot { + origins, + router_keys, + aspas, + created_at: DualTime::now(), + } + } + + pub fn diff(&self, new_snapshot: &Snapshot) -> (Vec, Vec) { + let mut announced = Vec::new(); + let mut withdrawn = Vec::new(); + + for origin in new_snapshot.origins.difference(&self.origins) { + announced.push(Payload::RouteOrigin(origin.clone())); + } + for origin in self.origins.difference(&new_snapshot.origins) { + withdrawn.push(Payload::RouteOrigin(origin.clone())); + } + + for key in new_snapshot.router_keys.difference(&self.router_keys) { + announced.push(Payload::RouterKey(key.clone())); + } + for key in self.router_keys.difference(&new_snapshot.router_keys) { + withdrawn.push(Payload::RouterKey(key.clone())); + } + + for aspa in new_snapshot.aspas.difference(&self.aspas) { + announced.push(Payload::Aspa(aspa.clone())); + } + for aspa in self.aspas.difference(&new_snapshot.aspas) { + withdrawn.push(Payload::Aspa(aspa.clone())); + } + + (announced, withdrawn) + } + + pub fn created_at(&self) -> DualTime { + self.created_at.clone() + } + + pub fn payloads(&self) -> Vec { + let mut v = Vec::with_capacity( + self.origins.len() + self.router_keys.len() + self.aspas.len(), + ); + + v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin)); + v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey)); + v.extend(self.aspas.iter().cloned().map(Payload::Aspa)); + + v + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Delta { + serial: u32, + announced: Vec, + withdrawn: Vec, + created_at: DualTime, +} + +impl Delta { + pub fn new(serial: u32, announced: Vec, withdrawn: Vec) -> Self { + Delta { + serial, + announced, + withdrawn, + created_at: DualTime::now(), + } + } + + pub fn serial(&self) -> u32 { + self.serial + } + + pub fn announced(&self) -> &[Payload] { + &self.announced + } + + pub fn withdrawn(&self) -> &[Payload] { + &self.withdrawn + } + + pub fn created_at(self) -> DualTime { + self.created_at + } +} + +#[derive(Debug)] +pub struct RtrCache { + // Session ID created at cache startup. + session_id: u16, + // Current serial. + pub serial: u32, + // Full snapshot. + pub snapshot: Snapshot, + // Delta window. + deltas: VecDeque>, + // Max number of deltas to keep. + max_delta: u8, + // Refresh interval. + refresh_interval: Duration, + // Last update begin time. + last_update_begin: DualTime, + // Last update end time. + last_update_end: DualTime, + // Cache created time. + created_at: DualTime, +} + +impl Default for RtrCache { + fn default() -> Self { + let now = DualTime::now(); + Self { + session_id: rand::random(), + serial: 0, + snapshot: Snapshot::empty(), + deltas: VecDeque::with_capacity(100), + max_delta: 100, + refresh_interval: Duration::from_secs(600), + last_update_begin: now.clone(), + last_update_end: now.clone(), + created_at: now, + } + } +} + +pub struct RtrCacheBuilder { + session_id: Option, + max_delta: Option, + refresh_interval: Option, + serial: Option, + snapshot: Option, + created_at: Option, +} + +impl RtrCacheBuilder { + pub fn new() -> Self { + Self { + session_id: None, + max_delta: None, + refresh_interval: None, + serial: None, + snapshot: None, + created_at: None, + } + } + + pub fn session_id(mut self, v: u16) -> Self { + self.session_id = Some(v); + self + } + + pub fn max_delta(mut self, v: u8) -> Self { + self.max_delta = Some(v); + self + } + + pub fn refresh_interval(mut self, v: Duration) -> Self { + self.refresh_interval = Some(v); + self + } + + pub fn serial(mut self, v: u32) -> Self { + self.serial = Some(v); + self + } + + pub fn snapshot(mut self, v: Snapshot) -> Self { + self.snapshot = Some(v); + self + } + + pub fn created_at(mut self, v: DualTime) -> Self { + self.created_at = Some(v); + self + } + + pub fn build(self) -> RtrCache { + let now = DualTime::now(); + let max_delta = self.max_delta.unwrap_or(100); + let refresh_interval = self.refresh_interval.unwrap_or(Duration::from_secs(600)); + let snapshot = self.snapshot.unwrap_or_else(Snapshot::empty); + let serial = self.serial.unwrap_or(0); + let created_at = self.created_at.unwrap_or_else(|| now.clone()); + let session_id = self.session_id.unwrap_or_else(rand::random); + + RtrCache { + session_id, + serial, + snapshot, + deltas: VecDeque::with_capacity(max_delta.into()), + max_delta, + refresh_interval, + last_update_begin: now.clone(), + last_update_end: now, + created_at, + } + } +} + +impl RtrCache { + /// Initialize cache from DB if possible; otherwise from file loader. + pub fn init( + self, + store: &RtrStore, + max_delta: u8, + refresh_interval: Duration, + file_loader: impl Fn() -> anyhow::Result>, + ) -> anyhow::Result { + let snapshot = store.get_snapshot()?; + let session_id = store.get_session_id()?; + let serial = store.get_serial()?; + + if let (Some(snapshot), Some(session_id), Some(serial)) = + (snapshot, session_id, serial) + { + let mut cache = RtrCacheBuilder::new() + .session_id(session_id) + .max_delta(max_delta) + .refresh_interval(refresh_interval) + .serial(serial) + .snapshot(snapshot) + .build(); + + if let Some((min_serial, _max_serial)) = store.get_delta_window()? { + let deltas = store.load_deltas_since(min_serial.wrapping_sub(1))?; + for delta in deltas { + cache.push_delta(Arc::new(delta)); + } + } + + return Ok(cache); + } + + let payloads = file_loader()?; + let snapshot = Snapshot::from_payloads(payloads); + let serial = 1; + let session_id: u16 = rand::random(); + + let store = store.clone(); + tokio::spawn(async move { + if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) { + tracing::error!("persist failed: {:?}", e); + } + }); + + Ok(RtrCacheBuilder::new() + .session_id(session_id) + .max_delta(max_delta) + .refresh_interval(refresh_interval) + .serial(serial) + .snapshot(snapshot) + .build()) + } + + fn next_serial(&mut self) -> u32 { + self.serial = self.serial.wrapping_add(1); + self.serial + } + + fn push_delta(&mut self, delta: Arc) { + if self.deltas.len() >= self.max_delta as usize { + self.deltas.pop_front(); + } + self.deltas.push_back(delta); + } + + fn replace_snapshot(&mut self, snapshot: Snapshot) { + self.snapshot = snapshot; + } + + fn delta_window(&self) -> Option<(u32, u32)> { + let min = self.deltas.front().map(|d| d.serial()); + let max = self.deltas.back().map(|d| d.serial()); + match (min, max) { + (Some(min), Some(max)) => Some((min, max)), + _ => None, + } + } + + fn store_sync( + &mut self, + store: &RtrStore, + snapshot: Snapshot, + serial: u32, + session_id: u16, + delta: Arc, + ) { + let window = self.delta_window(); + let store = store.clone(); + + tokio::spawn(async move { + if let Err(e) = store.save_delta(&delta) { + tracing::error!("persist delta failed: {:?}", e); + } + if let Err(e) = store.save_snapshot_and_meta(&snapshot, session_id, serial) { + tracing::error!("persist snapshot/meta failed: {:?}", e); + } + if let Some((min_serial, max_serial)) = window { + if let Err(e) = store.set_delta_window(min_serial, max_serial) { + tracing::error!("persist delta window failed: {:?}", e); + } + } + }); + } + + // Update cache. + pub fn update( + &mut self, + new_payloads: Vec, + store: &RtrStore, + ) -> anyhow::Result<()> { + let new_snapshot = Snapshot::from_payloads(new_payloads); + let (announced, withdrawn) = self.snapshot.diff(&new_snapshot); + + if announced.is_empty() && withdrawn.is_empty() { + return Ok(()); + } + + let new_serial = self.next_serial(); + let delta = Arc::new(Delta::new(new_serial, announced, withdrawn)); + + self.push_delta(delta.clone()); + self.replace_snapshot(new_snapshot.clone()); + self.last_update_end = DualTime::now(); + + self.store_sync(store, new_snapshot, new_serial, self.session_id, delta); + + Ok(()) + } + + pub fn session_id(&self) -> u16 { + self.session_id + } + + pub fn snapshot(&self) -> Snapshot { + self.snapshot.clone() + } + + pub fn serial(&self) -> u32 { + self.serial + } + + pub fn refresh_interval(&self) -> Duration { + self.refresh_interval + } + + pub fn retry_interval(&self) -> Duration { + DEFAULT_RETRY_INTERVAL + } + + pub fn expire_interval(&self) -> Duration { + DEFAULT_EXPIRE_INTERVAL + } + + pub fn current_snapshot(&self) -> (&Snapshot, u32, u16) { + (&self.snapshot, self.serial, self.session_id) + } +} + +impl RtrCache { + pub fn get_deltas_since( + &self, + client_session: u16, + client_serial: u32, + ) -> SerialResult { + if client_session != self.session_id { + return SerialResult::ResetRequired; + } + + if client_serial == self.serial { + return SerialResult::UpToDate; + } + + if self.deltas.is_empty() { + return SerialResult::ResetRequired; + } + + let oldest_serial = self.deltas.front().unwrap().serial; + let newest_serial = self.deltas.back().unwrap().serial; + + let min_supported = oldest_serial.wrapping_sub(1); + if client_serial < min_supported { + return SerialResult::ResetRequired; + } + + if client_serial > self.serial { + return SerialResult::ResetRequired; + } + + let mut result = Vec::new(); + for delta in &self.deltas { + if delta.serial > client_serial { + result.push(delta.clone()); + } + } + + if let Some(first) = result.first() { + if first.serial != client_serial.wrapping_add(1) { + return SerialResult::ResetRequired; + } + } else { + return SerialResult::UpToDate; + } + + SerialResult::Deltas(result) + } +} + +pub enum SerialResult { + /// Client is up to date. + UpToDate, + /// Return applicable deltas. + Deltas(Vec>), + /// Delta window cannot cover; reset required. + ResetRequired, +} diff --git a/src/rtr/error_type.rs b/src/rtr/error_type.rs new file mode 100644 index 0000000..95eaaa9 --- /dev/null +++ b/src/rtr/error_type.rs @@ -0,0 +1,98 @@ +use std::convert::TryFrom; +use std::fmt; + +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorCode { + CorruptData = 0, + InternalError = 1, + NoDataAvailable = 2, + InvalidRequest = 3, + UnsupportedProtocolVersion = 4, + UnsupportedPduType = 5, + WithdrawalOfUnknownRecord = 6, + DuplicateAnnouncement = 7, + UnexpectedProtocolVersion = 8, + AspaProviderListError = 9, + TransportFailed = 10, + OrderingError = 11, +} + +impl ErrorCode { + + #[inline] + pub fn as_u16(self) -> u16 { + self as u16 + } + + pub fn description(self) -> &'static str { + match self { + ErrorCode::CorruptData => + "Corrupt Data", + + ErrorCode::InternalError => + "Internal Error", + + ErrorCode::NoDataAvailable => + "No Data Available", + + ErrorCode::InvalidRequest => + "Invalid Request", + + ErrorCode::UnsupportedProtocolVersion => + "Unsupported Protocol Version", + + ErrorCode::UnsupportedPduType => + "Unsupported PDU Type", + + ErrorCode::WithdrawalOfUnknownRecord => + "Withdrawal of Unknown Record", + + ErrorCode::DuplicateAnnouncement => + "Duplicate Announcement Received", + + ErrorCode::UnexpectedProtocolVersion => + "Unexpected Protocol Version", + + ErrorCode::AspaProviderListError => + "ASPA Provider List Error", + + ErrorCode::TransportFailed => + "Transport Failed", + + ErrorCode::OrderingError => + "Ordering Error", + } + } +} + +impl TryFrom for ErrorCode { + type Error = (); + + fn try_from(value: u16) -> Result { + match value { + 0 => Ok(ErrorCode::CorruptData), + 1 => Ok(ErrorCode::InternalError), + 2 => Ok(ErrorCode::NoDataAvailable), + 3 => Ok(ErrorCode::InvalidRequest), + 4 => Ok(ErrorCode::UnsupportedProtocolVersion), + 5 => Ok(ErrorCode::UnsupportedPduType), + 6 => Ok(ErrorCode::WithdrawalOfUnknownRecord), + 7 => Ok(ErrorCode::DuplicateAnnouncement), + 8 => Ok(ErrorCode::UnexpectedProtocolVersion), + 9 => Ok(ErrorCode::AspaProviderListError), + 10 => Ok(ErrorCode::TransportFailed), + 11 => Ok(ErrorCode::OrderingError), + _ => Err(()), + } + } +} + +impl fmt::Display for ErrorCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} ({})", + self.description(), + *self as u16 + ) + } +} diff --git a/src/rtr/mod.rs b/src/rtr/mod.rs new file mode 100644 index 0000000..7062221 --- /dev/null +++ b/src/rtr/mod.rs @@ -0,0 +1,7 @@ +pub mod pdu; +pub mod cache; +pub mod payload; +mod store_db; +mod session; +mod error_type; +mod state; \ No newline at end of file diff --git a/src/rtr/payload.rs b/src/rtr/payload.rs new file mode 100644 index 0000000..c80b271 --- /dev/null +++ b/src/rtr/payload.rs @@ -0,0 +1,106 @@ +use std::fmt::Debug; +use std::sync::Arc; +use std::time::Duration; +use asn1_rs::nom::character::streaming::u64; +use serde::{Deserialize, Serialize}; +use crate::data_model::resources::as_resources::Asn; +use crate::data_model::resources::ip_resources::IPAddressPrefix; + + +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct Ski([u8; 20]); + + +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct RouteOrigin { + prefix: IPAddressPrefix, + max_length: u8, + asn: Asn, +} + +impl RouteOrigin { + pub fn prefix(&self) -> &IPAddressPrefix { + &self.prefix + } + + pub fn max_length(&self) -> u8 { + self.max_length + } + + pub fn asn(&self) -> Asn { + self.asn + } +} + + +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct RouterKey { + subject_key_identifier: Ski, + asn: Asn, + subject_public_key_info: Arc<[u8]>, +} + +impl RouterKey { + pub fn asn(&self) -> Asn { + self.asn + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct Aspa { + customer_asn: Asn, + provider_asns: Vec, +} + +impl Aspa { + pub fn customer_asn(&self) -> Asn { + self.customer_asn + } + + pub fn provider_asns(&self) -> &[Asn] { + &self.provider_asns + } +} + + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub enum Payload { + /// A route origin authorisation. + RouteOrigin(RouteOrigin), + + /// A BGPsec router key. + RouterKey(RouterKey), + + /// An ASPA unit. + Aspa(Aspa), +} + + +// Timing +#[derive(Clone, Copy, Debug)] +pub struct Timing { + /// The number of seconds until a client should refresh its data. + pub refresh: u32, + + /// The number of seconds a client whould wait before retrying to connect. + pub retry: u32, + + /// The number of secionds before data expires if not refreshed. + pub expire: u32 +} + +impl Timing { + pub fn refresh(self) -> Duration { + Duration::from_secs(u64::from(self.refresh)) + } + + pub fn retry(self) -> Duration { + Duration::from_secs(u64::from(self.retry)) + } + + pub fn expire(self) -> Duration { + Duration::from_secs(u64::from(self.expire)) + } + +} diff --git a/src/rtr/pdu.rs b/src/rtr/pdu.rs new file mode 100644 index 0000000..c83accd --- /dev/null +++ b/src/rtr/pdu.rs @@ -0,0 +1,868 @@ +use std::{cmp, mem}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; +use crate::data_model::resources::as_resources::Asn; +use crate::rtr::payload::{Ski, Timing}; +use std::io; +use std::io::Write; +use tokio::io::{AsyncWrite}; +use anyhow::Result; + +use std::slice; +use anyhow::bail; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +pub const HEADER_LEN: u32 = 8; +pub const MAX_PDU_LEN: u32 = 65535; +pub const IPV4_PREFIX_LEN: u32 = 20; +pub const IPV6_PREFIX_LEN: u32 = 32; +pub const END_OF_DATA_V0_LEN: u32 = 12; +pub const END_OF_DATA_V1_LEN: u32 = 24; +pub const ZERO_16: u16 = 0; +pub const ZERO_8: u8 = 0; + +macro_rules! common { + ( $type:ident ) => { + #[allow(dead_code)] + impl $type { + /// Writes a value to a writer. + pub async fn write( + &self, + a: &mut A + ) -> Result<(), io::Error> { + a.write_all(self.as_ref()).await + } + } + + impl AsRef<[u8]> for $type { + fn as_ref(&self) -> &[u8] { + unsafe { + slice::from_raw_parts( + self as *const Self as *const u8, + mem::size_of::() + ) + } + } + } + + impl AsMut<[u8]> for $type { + fn as_mut(&mut self) -> &mut [u8] { + unsafe { + slice::from_raw_parts_mut( + self as *mut Self as *mut u8, + mem::size_of::() + ) + } + } + } + } +} + +macro_rules! concrete { + ( $type:ident ) => { + common!($type); + + #[allow(dead_code)] + impl $type { + /// Returns the value of the version field of the header. + pub fn version(&self) -> u8 { + self.header.version() + } + + /// Returns the value of the session field of the header. + /// + /// Note that this field is used for other purposes in some PDU + /// types. + pub fn session_id(&self) -> u16 { + self.header.session_id() + } + + /// Returns the PDU size. + /// + /// The size is returned as a `u32` since that type is used in + /// the header. + pub fn size() -> u32 { + mem::size_of::() as u32 + } + + /// Reads a value from a reader. + /// + /// If a value with a different PDU type is received, returns an + /// error. + pub async fn read( + sock: &mut Sock + ) -> Result { + let mut res = Self::default(); + sock.read_exact(res.header.as_mut()).await?; + if res.header.pdu() != Self::PDU { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + concat!( + "PDU type mismatch when expecting ", + stringify!($type) + ) + )) + } + if res.header.length() as usize != res.as_ref().len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + concat!( + "invalid length for ", + stringify!($type) + ) + )) + } + sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; + Ok(res) + } + + /// Tries to read a value from a reader. + /// + /// If a different PDU type is received, returns the header as + /// the error case of the ok case. + pub async fn try_read( + sock: &mut Sock + ) -> Result, io::Error> { + let mut res = Self::default(); + sock.read_exact(res.header.as_mut()).await?; + if res.header.pdu() == Error::PDU { + // Since we should drop the session after an error, we + // can safely ignore all the rest of the error for now. + return Ok(Err(res.header)) + } + if res.header.pdu() != Self::PDU { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + concat!( + "PDU type mismatch when expecting ", + stringify!($type) + ) + )) + } + if res.header.length() as usize != res.as_ref().len() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + concat!( + "invalid length for ", + stringify!($type) + ) + )) + } + sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; + Ok(Ok(res)) + } + + /// Reads only the payload part of a value from a reader. + /// + /// Assuming that the header was already read and is passed via + /// `header`, the function reads the rest of the PUD from the + /// reader and returns the complete value. + pub async fn read_payload( + header: Header, sock: &mut Sock + ) -> Result { + if header.length() as usize != mem::size_of::() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + concat!( + "invalid length for ", + stringify!($type), + " PDU" + ) + )) + } + let mut res = Self::default(); + sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; + res.header = header; + Ok(res) + } + } + } +} + + +// 所有PDU公共头部信息 +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct Header { + version: u8, + pdu: u8, + session_id: u16, + length: u32, +} + +impl Header { + + const LEN: usize = mem::size_of::(); + pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self { + Header { + version, + pdu, + session_id: session.to_be(), + length: length.to_be(), + } + } + + pub async fn read(sock: &mut TcpStream) -> Result { + let mut buf = [0u8; HEADER_LEN]; + + // 1. 精确读取 8 字节 + sock.read_exact(&mut buf).await?; + + // 2. 手动解析(大端) + let version = buf[0]; + let pdu = buf[1]; + let reserved = u16::from_be_bytes([buf[2], buf[3]]); + let length = u32::from_be_bytes([ + buf[4], buf[5], buf[6], buf[7], + ]); + + // 3. 基础合法性校验 + + if length < HEADER_LEN{ + bail!("Invalid PDU length"); + } + + // 限制最大长度 + if length > MAX_PDU_LEN { + bail!("PDU too large"); + } + + Ok(Self { + version, + pdu, + session_id: reserved, + length, + }) + } + + pub fn version(self) -> u8{self.version} + + pub fn pdu(self) -> u8{self.pdu} + + pub fn session_id(self) -> u16{u16::from_be(self.session_id)} + + pub fn length(self) -> u32{u32::from_be(self.length)} + + pub fn pdu_len(self) -> Result { + usize::try_from(self.length()).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + "PDU too large for this system to handle", + ) + }) + } + + +} + +common!(Header); + +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct HeaderWithFlags { + version: u8, + pdu: u8, + flags: u8, + zero: u8, + length: u32, +} + +impl HeaderWithFlags { + pub fn new(version: u8, pdu: u8, flags: Flags, length: u32) -> Self { + HeaderWithFlags { + version, + pdu, + flags: flags.into_u8(), + zero: ZERO_8, + length: length.to_be(), + } + } + pub async fn read(sock: &mut TcpStream) -> Result { + let mut buf = [0u8; HEADER_LEN]; + + // 1. 精确读取 8 字节 + sock.read_exact(&mut buf).await?; + + // 2. 手动解析(大端) + let version = buf[0]; + let pdu = buf[1]; + let flags = buf[2]; + let zero = buf[3]; + let length = u32::from_be_bytes([ + buf[4], buf[5], buf[6], buf[7], + ]); + + // 3. 基础合法性校验 + if length < HEADER_LEN{ + bail!("Invalid PDU length"); + } + + // 限制最大长度 + if length > MAX_PDU_LEN { + bail!("PDU too large"); + } + + Ok(Self { + version, + pdu, + flags, + zero, + length, + }) + } + + pub fn version(self) -> u8{self.version} + + pub fn pdu(self) -> u8{self.pdu} + + pub fn flags(self) -> Flags{Flags(self.flags)} + + pub fn length(self) -> u32{u32::from_be(self.length)} +} + + +// Serial Notify +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct SerialNotify { + header: Header, + serial_number: u32, +} + +impl SerialNotify { + pub const PDU: u8 = 0; + + pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { + SerialNotify{ + header: Header::new(version, Self::PDU, session_id, Self::size()), + serial_number: serial_number.to_be() + } + } + +} +concrete!(SerialNotify); + + +// Serial Query +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct SerialQuery { + header: Header, + serial_number: u32, +} + +impl SerialQuery { + pub const PDU: u8 = 1; + + pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { + SerialQuery{ + header: Header::new(version, Self::PDU, session_id, Self::size()), + serial_number: serial_number.to_be() + } + } + + pub fn serial_number(self) -> u32 { + self.serial_number + } +} + +concrete!(SerialQuery); + + +// Reset Query +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct ResetQuery { + header: Header +} + +impl ResetQuery { + pub const PDU: u8 = 2; + + pub fn new(version: u8) -> Self { + ResetQuery { + header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN), + } + } +} + +concrete!(ResetQuery); + + +// Cache Response +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct CacheResponse { + header: Header, +} + +impl CacheResponse { + pub const PDU: u8 = 3; + + pub fn new(version: u8, session_id: u16) -> Self { + CacheResponse { + header: Header::new(version, Self::PDU, session_id, HEADER_LEN), + } + } +} + +concrete!(CacheResponse); + + +// Flags +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct Flags(u8); + +impl Flags { + pub fn new(raw: u8) -> Self { + Self(raw) + } + + pub fn is_announce(self) -> bool { + self.0 & 0x01 == 1 + } + + pub fn is_withdraw(self) -> bool { + !self.is_announce() + } + + pub fn into_u8(self) -> u8 { + self.0 + } +} + +// IPv4 Prefix +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct IPv4Prefix { + header: Header, + + flags: Flags, + prefix_len: u8, + max_len: u8, + zero: u8, + prefix: u32, + asn: u32 +} + +impl IPv4Prefix { + pub const PDU: u8 = 4; + pub fn new( + version: u8, + flags: Flags, + prefix_len: u8, + max_len: u8, + prefix: Ipv4Addr, + asn: Asn + ) -> Self { + IPv4Prefix { + header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN), + flags, + prefix_len, + max_len, + zero: ZERO_8, + prefix: u32::from(prefix).to_be(), + asn: asn.into_u32().to_be(), + } + } + + pub fn flag(self) -> Flags{self.flags} + + pub fn prefix_len(self) -> u8{self.prefix_len} + pub fn max_len(self) -> u8{self.max_len} + pub fn prefix(self) -> Ipv4Addr{u32::from_be(self.prefix).into()} + pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} +} + +concrete!(IPv4Prefix); + +// IPv6 Prefix +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct IPv6Prefix { + header: Header, + + flags: Flags, + prefix_len: u8, + max_len: u8, + zero: u8, + prefix: u128, + asn: u32 +} + +impl IPv6Prefix { + pub const PDU: u8 = 6; + pub fn new( + version: u8, + flags: Flags, + prefix_len: u8, + max_len: u8, + prefix: Ipv6Addr, + asn: Asn + ) -> Self { + IPv6Prefix { + header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN), + flags, + prefix_len, + max_len, + zero: ZERO_8, + prefix: u128::from(prefix).to_be(), + asn: asn.into_u32().to_be(), + } + } + + pub fn flag(self) -> Flags{self.flags} + + pub fn prefix_len(self) -> u8{self.prefix_len} + pub fn max_len(self) -> u8{self.max_len} + pub fn prefix(self) -> Ipv6Addr{u128::from_be(self.prefix).into()} + pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} +} + +concrete!(IPv6Prefix); + + +// End of Data +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum EndOfData { + V0(EndOfDataV0), + V1(EndOfDataV1), +} + +impl EndOfData { + pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self { + if version == 0 { + EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number)) + } + else { + EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)) + } + } + + +} + +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct EndOfDataV0 { + header: Header, + + serial_number: u32, +} + +impl EndOfDataV0 { + pub const PDU: u8 = 7; + + pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { + EndOfDataV0 { + header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V0_LEN), + serial_number: serial_number.to_be(), + } + } + + pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} +} +concrete!(EndOfDataV0); + + +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct EndOfDataV1 { + header: Header, + + serial_number: u32, + refresh_interval: u32, + retry_interval: u32, + expire_interval: u32, + +} + +impl EndOfDataV1 { + pub const PDU: u8 = 7; + + pub fn new(version: u8, session_id: u16, serial_number: u32, timing: Timing) -> Self { + EndOfDataV1 { + header: Header::new(version, Self::PDU, session_id, END_OF_DATA_V1_LEN), + serial_number: serial_number.to_be(), + refresh_interval: timing.refresh.to_be(), + retry_interval: timing.retry.to_be(), + expire_interval: timing.expire.to_be(), + } + } + + pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} + + pub fn timing(self) -> Timing{ + Timing { + refresh: u32::from_be(self.refresh_interval), + retry: u32::from_be(self.retry_interval), + expire: u32::from_be(self.expire_interval), + } + } +} +concrete!(EndOfDataV1); + +// Cache Reset +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +pub struct CacheReset { + header: Header, +} + +impl CacheReset { + pub const PDU: u8 = 8; + + pub fn new(version: u8) -> Self{ + CacheReset { + header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN) + } + } +} + +concrete!(CacheReset); + + +// Error Report +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct ErrorReport { + octets: Vec, +} + + +impl ErrorReport { + /// The PDU type of an error PDU. + pub const PDU: u8 = 10; + + /// Creates a new error PDU from components. + pub fn new( + version: u8, + error_code: u16, + pdu: impl AsRef<[u8]>, + text: impl AsRef<[u8]>, + ) -> Self { + let pdu = pdu.as_ref(); + let text = text.as_ref(); + + let size = + mem::size_of::
() + + 2 * mem::size_of::() + + pdu.len() + text.len() + ; + let header = Header::new( + version, 10, error_code, u32::try_from(size).unwrap() + ); + + let mut octets = Vec::with_capacity(size); + octets.extend_from_slice(header.as_ref()); + octets.extend_from_slice( + u32::try_from(pdu.len()).unwrap().to_be_bytes().as_ref() + ); + octets.extend_from_slice(pdu); + octets.extend_from_slice( + u32::try_from(text.len()).unwrap().to_be_bytes().as_ref() + ); + octets.extend_from_slice(text); + + ErrorReport { octets } + } + + /// Skips over the payload of the error PDU. + pub async fn skip_payload( + header: Header, sock: &mut Sock + ) -> Result<(), io::Error> { + let Some(mut remaining) = header.pdu_len()?.checked_sub( + mem::size_of::
() + ) else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "PDU size smaller than header size", + )) + }; + + let mut buf = [0u8; 1024]; + while remaining > 0 { + let read_len = cmp::min(remaining, mem::size_of_val(&buf)); + let read = sock.read( + // Safety: We limited the length to the buffer size. + unsafe { buf.get_unchecked_mut(..read_len) } + ).await?; + remaining -= read; + } + Ok(()) + } + + /// Writes the PUD to a writer. + pub async fn write( + &self, a: &mut A + ) -> Result<(), io::Error> { + a.write_all(self.as_ref()).await + } +} + + +// TODO: 补全 +// Router Key +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct RouterKey { + header: HeaderWithFlags, + + flags: Flags, + + ski: Ski, + asn: Asn, + subject_public_key_info: Arc<[u8]>, +} + +impl RouterKey { + + pub const PDU: u8 = 9; + + pub async fn write( + &self, + w: &mut A, + ) -> Result<(), io::Error> { + + let length = HEADER_LEN + + 1 // flags + // + self.ski.as_ref().len() + + 4 // ASN + + self.subject_public_key_info.len() as u32; + + let header = HeaderWithFlags::new( + self.header.version(), + Self::PDU, + self.flags, + length, + ); + + w.write_all(&[ + header.version(), + header.pdu(), + header.flags().into_u8(), + ZERO_8, + ]).await?; + + w.write_all(&length.to_be_bytes()).await?; + // w.write_all(self.ski.as_ref()).await?; + w.write_all(&self.asn.into_u32().to_be_bytes()).await?; + w.write_all(&self.subject_public_key_info).await?; + + Ok(()) + } +} + + + +// ASPA +pub struct Aspa{ + header: HeaderWithFlags, + + customer_asn: u32, + provider_asns: Vec +} + +impl Aspa { + pub const PDU: u8 = 11; + + pub async fn write( + &self, + w: &mut A, + ) -> Result<(), io::Error> { + + let length = HEADER_LEN + + 1 + + 4 + + (self.provider_asns.len() as u32 * 4); + + let header = HeaderWithFlags::new( + self.header.version(), + Self::PDU, + Flags::new(self.header.flags), + length, + ); + + w.write_all(&[ + header.version(), + header.pdu(), + header.flags().into_u8(), + ZERO_8, + ]).await?; + + w.write_all(&length.to_be_bytes()).await?; + w.write_all(&self.customer_asn.to_be_bytes()).await?; + + for asn in &self.provider_asns { + w.write_all(&asn.to_be_bytes()).await?; + } + + Ok(()) + } + +} + + +//--- AsRef and AsMut +impl AsRef<[u8]> for ErrorReport { + fn as_ref(&self) -> &[u8] { + self.octets.as_ref() + } +} + +impl AsMut<[u8]> for ErrorReport { + fn as_mut(&mut self) -> &mut [u8] { + self.octets.as_mut() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::duplex; + + #[tokio::test] + async fn test_serial_notify_roundtrip() { + let (mut client, mut server) = duplex(1024); + + let original = SerialNotify::new(1, 42, 100); + + // 写入 + tokio::spawn(async move { + original.write(&mut client).await.unwrap(); + }); + + // 读取 + let decoded = SerialNotify::read(&mut server).await.unwrap(); + + assert_eq!(decoded.version(), 1); + assert_eq!(decoded.session_id(), 42); + assert_eq!(decoded.serial_number, 100u32.to_be()); + } + + #[tokio::test] + async fn test_ipv4_prefix_roundtrip() { + use std::net::Ipv4Addr; + + let (mut client, mut server) = duplex(1024); + + let prefix = IPv4Prefix::new( + 1, + Flags::new(1), + 24, + 24, + Ipv4Addr::new(192,168,0,0), + 65000u32.into(), + ); + + tokio::spawn(async move { + prefix.write(&mut client).await.unwrap(); + }); + + let decoded = IPv4Prefix::read(&mut server).await.unwrap(); + + assert_eq!(decoded.prefix_len(), 24); + assert_eq!(decoded.max_len(), 24); + assert_eq!(decoded.prefix(), Ipv4Addr::new(192,168,0,0)); + assert_eq!(decoded.flag().is_announce(), true); + } +} diff --git a/src/rtr/session.rs b/src/rtr/session.rs new file mode 100644 index 0000000..05892a3 --- /dev/null +++ b/src/rtr/session.rs @@ -0,0 +1,280 @@ +use std::sync::Arc; + +use anyhow::{bail, Result}; +use tokio::io; +use tokio::net::TcpStream; +use tracing::warn; + +use crate::rtr::cache::{Delta, RtrCache, SerialResult}; +use crate::rtr::error_type::ErrorCode; +use crate::rtr::payload::{Payload, RouteOrigin, Timing}; +use crate::rtr::pdu::{ + CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix, + ResetQuery, SerialQuery, +}; + +const SUPPORTED_MAX_VERSION: u8 = 2; +const SUPPORTED_MIN_VERSION: u8 = 0; + +const ANNOUNCE_FLAG: u8 = 1; +const WITHDRAW_FLAG: u8 = 0; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SessionState { + Connected, + Established, + Closed, +} + +pub struct RtrSession { + cache: Arc, + version: Option, + stream: TcpStream, + state: SessionState, +} + +impl RtrSession { + pub fn new(cache: Arc, stream: TcpStream) -> Self { + Self { + cache, + version: None, + stream, + state: SessionState::Connected, + } + } + + pub async fn run(mut self) -> Result<()> { + loop { + let header = match Header::read(&mut self.stream).await { + Ok(h) => h, + Err(_) => return Ok(()), + }; + + if self.version.is_none() { + self.negotiate_version(header.version()).await?; + } else if header.version() != self.version.unwrap() { + self.send_unsupported_version(self.version.unwrap()).await?; + bail!("version changed within session"); + } + + match header.pdu() { + ResetQuery::PDU => { + let _ = ResetQuery::read_payload(header, &mut self.stream).await?; + self.handle_reset_query().await?; + } + SerialQuery::PDU => { + let query = SerialQuery::read_payload(header, &mut self.stream).await?; + let session_id = query.session_id(); + let serial = u32::from_be(query.serial_number()); + self.handle_serial(session_id, serial).await?; + } + ErrorReport::PDU => { + let _ = ErrorReport::skip_payload(header, &mut self.stream).await; + self.state = SessionState::Closed; + return Ok(()); + } + _ => { + self.send_error(header.version(), ErrorCode::UnsupportedPduType, Some(&header), &[]) + .await?; + return Ok(()); + } + } + } + } + + async fn negotiate_version(&mut self, router_version: u8) -> io::Result { + if router_version < SUPPORTED_MIN_VERSION { + self.send_unsupported_version(SUPPORTED_MIN_VERSION).await?; + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unsupported lower protocol version", + )); + } + + if router_version > SUPPORTED_MAX_VERSION { + self.send_unsupported_version(SUPPORTED_MAX_VERSION).await?; + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "router version higher than cache", + )); + } + + self.version = Some(router_version); + Ok(router_version) + } + + async fn send_unsupported_version(&mut self, cache_version: u8) -> io::Result<()> { + ErrorReport::new( + cache_version, + ErrorCode::UnsupportedProtocolVersion.as_u16(), + &[], + ErrorCode::UnsupportedProtocolVersion.description(), + ) + .write(&mut self.stream) + .await + } + + async fn handle_reset_query(&mut self) -> Result<()> { + self.state = SessionState::Established; + + let snapshot = self.cache.snapshot(); + self.write_cache_response().await?; + self.send_payloads(snapshot.payloads(), true).await?; + self.write_end_of_data(self.cache.session_id(), self.cache.serial()) + .await?; + + Ok(()) + } + + async fn handle_serial(&mut self, client_session: u16, client_serial: u32) -> Result<()> { + let current_session = self.cache.session_id(); + let current_serial = self.cache.serial(); + + match self.cache.get_deltas_since(client_session, client_serial) { + SerialResult::ResetRequired => { + self.write_cache_reset().await?; + return Ok(()); + } + SerialResult::UpToDate => { + self.write_end_of_data(current_session, current_serial) + .await?; + return Ok(()); + } + SerialResult::Deltas(deltas) => { + self.write_cache_response().await?; + for delta in deltas { + self.send_delta(&delta).await?; + } + self.write_end_of_data(current_session, current_serial) + .await?; + } + } + + self.state = SessionState::Established; + Ok(()) + } + + async fn write_cache_response(&mut self) -> Result<()> { + let version = self.version.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") + })?; + + CacheResponse::new(version, self.cache.session_id()) + .write(&mut self.stream) + .await?; + Ok(()) + } + + async fn write_cache_reset(&mut self) -> Result<()> { + let version = self.version.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") + })?; + + CacheReset::new(version).write(&mut self.stream).await?; + Ok(()) + } + + async fn write_end_of_data(&mut self, session_id: u16, serial: u32) -> Result<()> { + let version = self.version.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") + })?; + + let timing = self.timing(); + let end = EndOfData::new(version, session_id, serial, timing); + match end { + EndOfData::V0(pdu) => pdu.write(&mut self.stream).await?, + EndOfData::V1(pdu) => pdu.write(&mut self.stream).await?, + } + + Ok(()) + } + + async fn send_payloads(&mut self, payloads: Vec, announce: bool) -> Result<()> { + for payload in payloads { + self.send_payload(&payload, announce).await?; + } + Ok(()) + } + + async fn send_delta(&mut self, delta: &Arc) -> Result<()> { + for payload in delta.withdrawn() { + self.send_payload(payload, false).await?; + } + for payload in delta.announced() { + self.send_payload(payload, true).await?; + } + Ok(()) + } + + async fn send_payload(&mut self, payload: &Payload, announce: bool) -> Result<()> { + match payload { + Payload::RouteOrigin(origin) => { + self.send_route_origin(origin, announce).await?; + } + Payload::RouterKey(_) => { + warn!("router key payload not supported yet"); + } + Payload::Aspa(_) => { + warn!("aspa payload not supported yet"); + } + } + Ok(()) + } + + async fn send_route_origin(&mut self, origin: &RouteOrigin, announce: bool) -> Result<()> { + let version = self.version.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "version not negotiated") + })?; + + let flags = Flags::new(if announce { + ANNOUNCE_FLAG + } else { + WITHDRAW_FLAG + }); + + let prefix = origin.prefix(); + let prefix_len = prefix.prefix_length; + let max_len = origin.max_length(); + + if let Some(v4) = prefix.address.to_ipv4() { + IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn()) + .write(&mut self.stream) + .await?; + } else { + let v6 = prefix.address.to_ipv6(); + IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn()) + .write(&mut self.stream) + .await?; + } + + Ok(()) + } + + async fn send_error( + &mut self, + version: u8, + code: ErrorCode, + offending_header: Option<&Header>, + text: &[u8], + ) -> io::Result<()> { + let offending = offending_header + .map(|h| h.as_ref()) + .unwrap_or(&[]); + + ErrorReport::new(version, code.as_u16(), offending, text) + .write(&mut self.stream) + .await + } + + fn timing(&self) -> Timing { + let refresh = self.cache.refresh_interval().as_secs(); + let retry = self.cache.retry_interval().as_secs(); + let expire = self.cache.expire_interval().as_secs(); + + Timing { + refresh: refresh.min(u32::MAX as u64) as u32, + retry: retry.min(u32::MAX as u64) as u32, + expire: expire.min(u32::MAX as u64) as u32, + } + } +} diff --git a/src/rtr/state.rs b/src/rtr/state.rs new file mode 100644 index 0000000..b33c23a --- /dev/null +++ b/src/rtr/state.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct State { + session_id: u16, + serial: u32, +} + +impl State { + pub fn session_id(self) -> u16 { + self.session_id + } + + pub fn serial(self) -> u32 { + self.serial + } +} diff --git a/src/rtr/store_db.rs b/src/rtr/store_db.rs new file mode 100644 index 0000000..e8f6400 --- /dev/null +++ b/src/rtr/store_db.rs @@ -0,0 +1,300 @@ +use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch}; +use anyhow::{anyhow, Result}; +use serde::{de::DeserializeOwned, Serialize}; +use std::path::Path; +use std::sync::Arc; +use tokio::task; + +use crate::rtr::cache::{Delta, Snapshot}; +use crate::rtr::state::State; + +const CF_META: &str = "meta"; +const CF_SNAPSHOT: &str = "snapshot"; +const CF_DELTA: &str = "delta"; + +const META_STATE: &[u8] = b"state"; +const META_SESSION_ID: &[u8] = b"session_id"; +const META_SERIAL: &[u8] = b"serial"; +const META_DELTA_MIN: &[u8] = b"delta_min"; +const META_DELTA_MAX: &[u8] = b"delta_max"; + +const DELTA_KEY_PREFIX: u8 = b'd'; + +fn delta_key(serial: u32) -> [u8; 5] { + let mut key = [0u8; 5]; + key[0] = DELTA_KEY_PREFIX; + key[1..].copy_from_slice(&serial.to_be_bytes()); + key +} + +fn delta_key_serial(key: &[u8]) -> Option { + if key.len() != 5 || key[0] != DELTA_KEY_PREFIX { + return None; + } + let mut bytes = [0u8; 4]; + bytes.copy_from_slice(&key[1..]); + Some(u32::from_be_bytes(bytes)) +} + +#[derive(Clone)] +pub struct RtrStore { + db: Arc, +} + +impl RtrStore { + /// Open or create DB with required column families. + pub fn open>(path: P) -> Result { + let mut opts = Options::default(); + opts.create_if_missing(true); + opts.create_missing_column_families(true); + + let cfs = vec![ + ColumnFamilyDescriptor::new(CF_META, Options::default()), + ColumnFamilyDescriptor::new(CF_SNAPSHOT, Options::default()), + ColumnFamilyDescriptor::new(CF_DELTA, Options::default()), + ]; + + let db = Arc::new(DB::open_cf_descriptors(&opts, path, cfs)?); + + Ok(Self { db }) + } + + /// Common serialize/put. + fn put_cf(&self, cf: &str, key: &[u8], value: &T) -> Result<()> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + let data = serde_json::to_vec(value)?; + self.db.put_cf(cf_handle, key, data)?; + Ok(()) + } + + /// Common get/deserialize. + fn get_cf(&self, cf: &str, key: &[u8]) -> Result> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + if let Some(value) = self.db.get_cf(cf_handle, key)? { + let obj = serde_json::from_slice(&value)?; + Ok(Some(obj)) + } else { + Ok(None) + } + } + + /// Common delete. + fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> { + let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + self.db.delete_cf(cf_handle, key)?; + Ok(()) + } + + // =============================== + // Meta/state + // =============================== + + pub fn set_state(&self, state: &State) -> Result<()> { + self.put_cf(CF_META, META_STATE, &state) + } + + pub fn get_state(&self) -> Result> { + self.get_cf(CF_META, META_STATE) + } + + pub fn set_meta(&self, meta: &State) -> Result<()> { + self.set_state(meta) + } + + pub fn get_meta(&self) -> Result> { + self.get_state() + } + + pub fn set_session_id(&self, session_id: u16) -> Result<()> { + self.put_cf(CF_META, META_SESSION_ID, &session_id) + } + + pub fn get_session_id(&self) -> Result> { + self.get_cf(CF_META, META_SESSION_ID) + } + + pub fn set_serial(&self, serial: u32) -> Result<()> { + self.put_cf(CF_META, META_SERIAL, &serial) + } + + pub fn get_serial(&self) -> Result> { + self.get_cf(CF_META, META_SERIAL) + } + + pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> { + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let mut batch = WriteBatch::default(); + batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); + batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub fn get_delta_window(&self) -> Result> { + let min: Option = self.get_cf(CF_META, META_DELTA_MIN)?; + let max: Option = self.get_cf(CF_META, META_DELTA_MAX)?; + + match (min, max) { + (Some(min), Some(max)) => Ok(Some((min, max))), + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: delta window mismatch")), + } + } + + pub fn delete_state(&self) -> Result<()> { + self.delete_cf(CF_META, META_STATE) + } + + pub fn delete_serial(&self) -> Result<()> { + self.delete_cf(CF_META, META_SERIAL) + } + + // =============================== + // Snapshot + // =============================== + + pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> { + let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let mut batch = WriteBatch::default(); + let data = serde_json::to_vec(snapshot)?; + batch.put_cf(cf_handle, b"current", data); + self.db.write(batch)?; + Ok(()) + } + + pub fn get_snapshot(&self) -> Result> { + self.get_cf(CF_SNAPSHOT, b"current") + } + + pub fn delete_snapshot(&self) -> Result<()> { + self.delete_cf(CF_SNAPSHOT, b"current") + } + + pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> { + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let mut batch = WriteBatch::default(); + + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_STATE, serde_json::to_vec(state)?); + batch.put_cf( + meta_cf, + META_SESSION_ID, + serde_json::to_vec(&state.clone().session_id())?, + ); + batch.put_cf( + meta_cf, + META_SERIAL, + serde_json::to_vec(&state.clone().serial())?, + ); + + self.db.write(batch)?; + Ok(()) + } + + pub fn save_snapshot_and_meta( + &self, + snapshot: &Snapshot, + session_id: u16, + serial: u32, + ) -> Result<()> { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_SESSION_ID, serde_json::to_vec(&session_id)?); + batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); + batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); + self.db.write(batch)?; + Ok(()) + } + + pub async fn save_snapshot_and_serial_async( + self: Arc, + snapshot: Snapshot, + serial: u32, + ) -> Result<()> { + let snapshot_bytes = serde_json::to_vec(&snapshot)?; + let serial_bytes = serde_json::to_vec(&serial)?; + + task::spawn_blocking(move || { + let mut batch = WriteBatch::default(); + let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + batch.put_cf(snapshot_cf, b"current", snapshot_bytes); + batch.put_cf(meta_cf, META_SERIAL, serial_bytes); + self.db.write(batch)?; + Ok::<_, anyhow::Error>(()) + }) + .await??; + + Ok(()) + } + + pub fn load_snapshot_and_state(&self) -> Result> { + let snapshot: Option = self.get_snapshot()?; + let state: Option = self.get_state()?; + match (snapshot, state) { + (Some(snap), Some(state)) => Ok(Some((snap, state))), + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")), + } + } + + pub fn load_snapshot_and_serial(&self) -> Result> { + let snapshot: Option = self.get_snapshot()?; + let serial: Option = self.get_serial()?; + match (snapshot, serial) { + (Some(snap), Some(serial)) => Ok(Some((snap, serial))), + (None, None) => Ok(None), + _ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")), + } + } + + // =============================== + // Delta + // =============================== + + pub fn save_delta(&self, delta: &Delta) -> Result<()> { + self.put_cf(CF_DELTA, &delta_key(delta.serial()), delta) + } + + pub fn get_delta(&self, serial: u32) -> Result> { + self.get_cf(CF_DELTA, &delta_key(serial)) + } + + pub fn load_deltas_since(&self, serial: u32) -> Result> { + let cf_handle = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let mut out = Vec::new(); + let start_key = delta_key(serial.wrapping_add(1)); + + let iter = self + .db + .iterator_cf(cf_handle, IteratorMode::From(&start_key, Direction::Forward)); + + for (key, value) in iter { + let parsed = delta_key_serial(&key).ok_or_else(|| anyhow!("Invalid delta key"))?; + if parsed <= serial { + continue; + } + let delta: Delta = serde_json::from_slice(&value)?; + out.push(delta); + } + + Ok(out) + } + + pub fn delete_delta(&self, serial: u32) -> Result<()> { + self.delete_cf(CF_DELTA, &delta_key(serial)) + } +} diff --git a/src/slurm/mod.rs b/src/slurm/mod.rs new file mode 100644 index 0000000..0ade6c8 --- /dev/null +++ b/src/slurm/mod.rs @@ -0,0 +1 @@ +mod slurm; \ No newline at end of file diff --git a/src/slurm/slurm.rs b/src/slurm/slurm.rs new file mode 100644 index 0000000..7b4b268 --- /dev/null +++ b/src/slurm/slurm.rs @@ -0,0 +1,80 @@ +use std::io; +use crate::data_model::resources::as_resources::Asn; + + +#[derive(Debug, thiserror::Error)] +pub enum SlurmError { + #[error("Read slurm from reader error")] + SlurmFromReader(), +} + + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SlurmFile { + pub version: u32, + pub validation_output_filters: ValidationOutputFilters, + pub locally_added_assertions: LocallyAddedAssertions, +} + +impl SlurmFile { + pub fn new(filters: ValidationOutputFilters, + assertions: LocallyAddedAssertions,) -> Self { + let version = 1; + SlurmFile { + version, + validation_output_filters: filters, + locally_added_assertions: assertions, + } + } + + // pub fn from_reader(reader: impl io::Read)-> Result { + // + // } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidationOutputFilters { + pub prefix_filters: Vec, + pub bgpset_filters: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Comment(String); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrefixFilter { + pub prefix: String, + pub asn: Asn, + pub comment: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BgpsecFilter { + pub asn: Asn, + pub ski: u8, + pub comment: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocallyAddedAssertions { + pub prefix_assertions: Vec, + pub bgpsec_assertions: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrefixAssertion { + pub prefix: String, + pub asn: Asn, + pub max_prefix_length: u8, + pub comment: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BgpsecAssertion { + pub asn: Asn, + pub ski: u8, + pub router_public_key: u8, + pub comment: Option, +} + +