diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..6eef9bf --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +.git +.gitignore +.idea +target +tmp_slurm_output.json +rtr-db +tests +specs +scripts +README.md diff --git a/.gitignore b/.gitignore index 6e61639..1d9b68b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ target/ Cargo.lock -rtr-db/ \ No newline at end of file +rtr-db/ +.idea/ diff --git a/Cargo.toml b/Cargo.toml index 85d3734..f706324 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ rand = "0.10.0" rocksdb = { version = "0.21.0", default-features = false } serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" +base64 = "0.22" anyhow = "1" tracing = "0.1.44" sha2 = "0.10" diff --git a/data/20260324T000037Z-sng1.ccr b/data/20260324T000037Z-sng1.ccr new file mode 100644 index 0000000..128af6d Binary files /dev/null and b/data/20260324T000037Z-sng1.ccr differ diff --git a/data/20260324T000138Z-zur1.ccr b/data/20260324T000138Z-zur1.ccr new file mode 100644 index 0000000..3040cc1 Binary files /dev/null and b/data/20260324T000138Z-zur1.ccr differ diff --git a/data/example.slurm b/data/example.slurm new file mode 100644 index 0000000..0a712ee --- /dev/null +++ b/data/example.slurm @@ -0,0 +1,23 @@ +{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [ + { + "prefix": "24.0.0.0/8", + "comment": "Filter many VRPs in current CCR sample" + } + ], + "bgpsecFilters": [], + "aspaFilters": [ + { + "customerAsn": 80, + "comment": "Filter one ASPA known to exist in current CCR sample" + } + ] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [], + "aspaAssertions": [] + } +} diff --git a/deploy/DEPLOYMENT.md b/deploy/DEPLOYMENT.md new file mode 100644 index 0000000..a5bb9cb --- /dev/null +++ b/deploy/DEPLOYMENT.md @@ -0,0 +1,40 @@ +# Deployment (Supervisor + Docker Compose) + +This project runs `src/main.rs` as a long-running server that: + +1. loads the latest `.ccr` from a configured directory, +2. applies optional SLURM filtering, +3. starts the RTR server. + +`supervisord` is used as PID 1 in the container to keep the process managed and auto-restarted.
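For a quick sanity check of that data path (latest `.ccr` → optional SLURM → RTR payloads), the pipeline entry point this change adds can also be called directly; a minimal sketch, assuming the container default paths and that it is compiled inside this workspace:

```rust
use rpki::source::pipeline::{PayloadLoadConfig, load_payloads_from_latest_sources};

fn main() -> anyhow::Result<()> {
    // Mirrors the container defaults from deploy/Dockerfile; point these at
    // local directories when running outside Docker.
    let cfg = PayloadLoadConfig {
        ccr_dir: "/app/data".to_string(),
        slurm_dir: Some("/app/slurm".to_string()),
        strict_ccr_validation: false,
    };

    // Resolves the newest .ccr under ccr_dir, applies SLURM policy from
    // slurm_dir when configured, and returns the resulting RTR payload set.
    let payloads = load_payloads_from_latest_sources(&cfg)?;
    println!("loaded {} payloads", payloads.len());
    Ok(())
}
```

The server performs the same load on startup and on every refresh tick, handing the resulting payloads to the RTR cache.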
+ +## Files + +- `deploy/Dockerfile` +- `deploy/supervisord.conf` +- `deploy/docker-compose.yml` + +## Runtime Paths in Container + +- CCR directory: `/app/data` +- RocksDB directory: `/app/rtr-db` +- SLURM directory: `/app/slurm` +- TLS cert directory (optional): `/app/certs` + +## Start + +```bash +docker compose -f deploy/docker-compose.yml up -d --build +``` + +## Stop + +```bash +docker compose -f deploy/docker-compose.yml down +``` + +## Logs + +```bash +docker compose -f deploy/docker-compose.yml logs -f rpki-rtr +``` diff --git a/deploy/Dockerfile b/deploy/Dockerfile new file mode 100644 index 0000000..fd19891 --- /dev/null +++ b/deploy/Dockerfile @@ -0,0 +1,34 @@ +FROM rust:1.86-bookworm AS builder + +WORKDIR /build + +COPY Cargo.toml Cargo.lock ./ +COPY src ./src + +RUN cargo build --release --bin rpki + +FROM debian:bookworm-slim AS runtime + +RUN apt-get update \ + && apt-get install -y --no-install-recommends ca-certificates supervisor \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY --from=builder /build/target/release/rpki /usr/local/bin/rpki +COPY deploy/supervisord.conf /etc/supervisor/conf.d/rpki-rtr.conf + +RUN mkdir -p /app/data /app/rtr-db /app/certs /app/slurm /var/log/supervisor + +ENV RPKI_RTR_ENABLE_TLS=false \ + RPKI_RTR_TCP_ADDR=0.0.0.0:323 \ + RPKI_RTR_TLS_ADDR=0.0.0.0:324 \ + RPKI_RTR_DB_PATH=/app/rtr-db \ + RPKI_RTR_CCR_DIR=/app/data \ + RPKI_RTR_SLURM_DIR=/app/slurm \ + RPKI_RTR_REFRESH_INTERVAL_SECS=300 \ + RPKI_RTR_STRICT_CCR_VALIDATION=false + +EXPOSE 323 324 + +CMD ["supervisord", "-n", "-c", "/etc/supervisor/conf.d/rpki-rtr.conf"] diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml new file mode 100644 index 0000000..0dcf0e0 --- /dev/null +++ b/deploy/docker-compose.yml @@ -0,0 +1,28 @@ +version: "3.9" + +services: + rpki-rtr: + build: + context: .. + dockerfile: deploy/Dockerfile + image: rpki-rtr:latest + container_name: rpki-rtr + restart: unless-stopped + ports: + - "323:323" + - "324:324" + environment: + RPKI_RTR_ENABLE_TLS: "false" + RPKI_RTR_TCP_ADDR: "0.0.0.0:323" + RPKI_RTR_TLS_ADDR: "0.0.0.0:324" + RPKI_RTR_DB_PATH: "/app/rtr-db" + RPKI_RTR_CCR_DIR: "/app/data" + RPKI_RTR_SLURM_DIR: "/app/slurm" + RPKI_RTR_STRICT_CCR_VALIDATION: "false" + RPKI_RTR_REFRESH_INTERVAL_SECS: "300" + volumes: + - ../data:/app/data:ro + - ../rtr-db:/app/rtr-db + - ../data:/app/slurm:ro + # TLS mode example: + # - ../certs:/app/certs:ro diff --git a/deploy/supervisord.conf b/deploy/supervisord.conf new file mode 100644 index 0000000..f35bc5b --- /dev/null +++ b/deploy/supervisord.conf @@ -0,0 +1,18 @@ +[supervisord] +nodaemon=true +logfile=/dev/null +pidfile=/tmp/supervisord.pid + +[program:rpki-rtr] +command=/usr/local/bin/rpki +autostart=true +autorestart=true +startsecs=2 +startretries=3 +stopsignal=TERM +stopasgroup=true +killasgroup=true +stdout_logfile=/dev/fd/1 +stdout_logfile_maxbytes=0 +stderr_logfile=/dev/fd/2 +stderr_logfile_maxbytes=0 diff --git a/specs/10_slurm.md b/specs/10_slurm.md index 698ce2a..8a2edab 100644 --- a/specs/10_slurm.md +++ b/specs/10_slurm.md @@ -1,34 +1,23 @@ -# 10. SLURM(Simplified Local Internet Number Resource Management with the RPKI) +# 10. 
SLURM(Simplified Local Internet Number Resource Management with the RPKI) -## 10.1 对象定位 +## 10.1 目标与范围 -SLURM是一个JSON文件,允许 RPKI 依赖方在本地“覆盖/修正/忽略”来自上游RPKI数据的内容,而不需要修改或伪造原始RPKI对象。 +SLURM 用于让 RP(Relying Party)在本地对上游 RPKI 验证结果做“过滤”和“补充断言”,而不修改上游发布对象。 -## 10.2 数据格式 (RFC 8416 §3) +本文档基于: +- RFC 8416(SLURM v1,ROA/BGPsec) +- draft-ietf-sidrops-aspa-slurm-04(SLURM v2,新增 ASPA) -### SLURM +## 10.2 版本与顶层结构 -SLURM是一个只包含一个JSON对象的文件。格式要求如下(RFC 8416 §3.2): +### 10.2.1 SLURM v1(RFC 8416) -```text -A SLURM file consists of a single JSON object containing the -following members: - o A "slurmVersion" member that MUST be set to 1, encoded as a number - o A "validationOutputFilters" member (Section 3.3), whose value is - an object. The object MUST contain exactly two members: - * A "prefixFilters" member, whose value is described in - Section 3.3.1. - * A "bgpsecFilters" member, whose value is described in - Section 3.3.2. - o A "locallyAddedAssertions" member (Section 3.4), whose value is an - object. The object MUST contain exactly two members: - * A "prefixAssertions" member, whose value is described in - Section 3.4.1. - * A "bgpsecAssertions" member, whose value is described in - Section 3.4.2. -``` +`slurmVersion` 必须为 `1`,且顶层 JSON 对象必须包含且仅包含以下成员: +- `slurmVersion` +- `validationOutputFilters`(必须包含 `prefixFilters`、`bgpsecFilters`) +- `locallyAddedAssertions`(必须包含 `prefixAssertions`、`bgpsecAssertions`) -一个空的SLURM json结构体如下: +空策略示例: ```json { @@ -44,193 +33,154 @@ following members: } ``` -### prefixFilters -其中`prefixFilters`格式要求如下(RFC 8416 §3.3.1): +### 10.2.2 SLURM v2(draft-04) + +`slurmVersion` 必须为 `2`,在 v1 基础上扩展 ASPA 两类成员: +- `validationOutputFilters.aspaFilters` +- `locallyAddedAssertions.aspaAssertions` + +空策略示例: -```text -The above is expressed as a value of the "prefixFilters" member, as -an array of zero or more objects. Each object MUST contain either 1) -one of the following members or 2) one of each of the following -members. - o A "prefix" member, whose value is a string representing either an - IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see - [RFC5952]). - o An "asn" member, whose value is a number. - In addition, each object MAY contain one optional "comment" member, - whose value is a string. -``` -示例: ```json -"prefixFilters": [ - { - "prefix": "192.0.2.0/24", - "comment": "All VRPs encompassed by prefix" - }, - { - "asn": 64496, - "comment": "All VRPs matching ASN" - }, - { - "prefix": "198.51.100.0/24", - "asn": 64497, - "comment": "All VRPs encompassed by prefix, matching ASN" +{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [], + "aspaFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [], + "aspaAssertions": [] } -] +} ``` -### bgpsecFilters -`bgpsecFilters`格式要求如下(RFC 8416 §3.3.2) +## 10.3 字段规范(RFC 8416) -```text -The above is expressed as a value of the "bgpsecFilters" member, as -an array of zero or more objects. Each object MUST contain one of -either, or one each of both following members: - o An "asn" member, whose value is a number - o An "SKI" member, whose value is the Base64 encoding without - trailing ’=’ (Section 5 of [RFC4648]) of the certificate’s Subject - Key Identifier as described in Section 4.8.2 of [RFC6487]. (This - is the value of the ASN.1 OCTET STRING without the ASN.1 tag or - length fields.) -In addition, each object MAY contain one optional "comment" member, -whose value is a string. 
-``` +### 10.3.1 `prefixFilters` + +数组元素每项: +- 必须至少包含一个:`prefix` 或 `asn` +- 可选:`comment` + +匹配规则: +- 若配置了 `prefix`:匹配“被该前缀覆盖(encompassed)”的 VRP 前缀 +- 若配置了 `asn`:匹配该 ASN +- 同时配置时:两者都要匹配 + +### 10.3.2 `bgpsecFilters` + +数组元素每项: +- 必须至少包含一个:`asn` 或 `SKI` +- 可选:`comment` + +匹配规则: +- 按 `asn`/`SKI` 单独或联合匹配 Router Key(BGPsec) + +### 10.3.3 `prefixAssertions` + +数组元素每项: +- 必须:`prefix`、`asn` +- 可选:`maxPrefixLength`、`comment` + +约束: +- 若给出 `maxPrefixLength`,应满足 `prefix 长度 <= maxPrefixLength <= 地址位宽(IPv4=32, IPv6=128)` + +### 10.3.4 `bgpsecAssertions` + +数组元素每项: +- 必须:`asn`、`SKI`、`routerPublicKey` +- 可选:`comment` + +## 10.4 ASPA 扩展(draft-ietf-sidrops-aspa-slurm-04) + +### 10.4.1 `aspaFilters` + +数组元素每项: +- 必须:`customerAsn` +- 可选:`comment` + +匹配规则: +- 当 VAP(Validated ASPA Payload)的 `customerAsn` 等于过滤器 `customerAsn` 时命中并移除。 + +### 10.4.2 `aspaAssertions` + +数组元素每项: +- 必须:`customerAsn` +- 必须:`providerAsns`(ASN 数组) +- 可选:`comment` + +关键约束(draft-04): +- `customerAsn` 不得出现在 `providerAsns` 中 +- `providerAsns` 必须按升序排列 +- `providerAsns` 里的 ASN 必须唯一(无重复) + +语义补充(draft-04): +- `aspaAssertions` 仅用于“新增断言”,不构成隐式过滤(不会自动替代 `aspaFilters`)。 +- 在 RTRv2 输出阶段,新增的 ASPA 断言应加入 ASPA PDU 集合,并做去重。 + +## 10.5 应用语义(RFC 8416 Section 4) + +### 10.5.1 原子性 + +SLURM 应用必须是原子的: +- 要么完全不生效(等同未使用 SLURM) +- 要么完整按当前 SLURM 配置生效 + +### 10.5.2 处理顺序 + +在同一次计算中: +1. 先执行 `validationOutputFilters`(移除匹配验证结果) +2. 再追加 `locallyAddedAssertions` + +### 10.5.3 多文件 + +实现可以支持多个 SLURM 文件并行使用(取并集),但在启用前应检查断言重叠冲突;若存在冲突,整组文件应被拒绝。 + +## 10.6 最小可用示例(SLURM v2) -示例: ```json -"bgpsecFilters": [ - { - "asn": 64496, - "comment": "All keys for ASN" - }, - { - "SKI": "", - "comment": "Key matching Router SKI" - }, - { - "asn": 64497, - "SKI": "", - "comment": "Key for ASN 64497 matching Router SKI" - } -] +{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [ + { + "prefix": "203.0.113.0/24", + "comment": "Filter a broken VRP from upstream" + } + ], + "bgpsecFilters": [], + "aspaFilters": [ + { + "customerAsn": 64496, + "comment": "Filter one customer ASPA" + } + ] + }, + "locallyAddedAssertions": { + "prefixAssertions": [ + { + "asn": 64496, + "prefix": "203.0.113.0/24", + "maxPrefixLength": 24, + "comment": "Local business exception" + } + ], + "bgpsecAssertions": [], + "aspaAssertions": [ + { + "customerAsn": 64496, + "providerAsns": [64497, 64498], + "comment": "Local ASPA assertion" + } + ] + } +} ``` -### prefixAssertions -`prefixAssertions`格式要求如下(RFC 8416 §3.4.1) -```text -The above is expressed as a value of the "prefixAssertions" member, -as an array of zero or more objects. Each object MUST contain one of -each of the following members: - o A "prefix" member, whose value is a string representing either an - IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see - [RFC5952]). - o An "asn" member, whose value is a number. -In addition, each object MAY contain one of each of the following -members: - o A "maxPrefixLength" member, whose value is a number. - o A "comment" member, whose value is a string. -``` - -示例: -```json -"prefixAssertions": [ - { - "asn": 64496, - "prefix": "198.51.100.0/24", - "comment": "My other important route" - }, - { - "asn": 64496, - "prefix": "2001:DB8::/32", - "maxPrefixLength": 48, - "comment": "My other important de-aggregated routes" - } -] -``` - -### bgpsecAssertions -`bgpsecAssertions`格式要求如下(RFC 8416 §3.4.2) -```text -The above is expressed as a value of the "bgpsecAssertions" member, -as an array of zero or more objects. 
Each object MUST contain one -each of all of the following members: - o An "asn" member, whose value is a number. - o An "SKI" member, whose value is the Base64 encoding without - trailing ’=’ (Section 5 of [RFC4648]) of the certificate’s Subject - Key Identifier as described in Section 4.8.2 of [RFC6487] (This is - the value of the ASN.1 OCTET STRING without the ASN.1 tag or - length fields.) - o A "routerPublicKey" member, whose value is the Base64 encoding - without trailing ’=’ (Section 5 of [RFC4648]) of the equivalent to - the subjectPublicKeyInfo value of the router certificate’s public - key, as described in [RFC8208]. This is the full ASN.1 DER - encoding of the subjectPublicKeyInfo, including the ASN.1 tag and - length values of the subjectPublicKeyInfo SEQUENCE. -``` -示例: -```json -"bgpsecAssertions": [ - { - "asn": 64496, - "SKI": "", - "routerPublicKey": "", - "comment": "My known key for my important ASN" - } -] -``` - -## 10.3 抽象数据结构 - -### SLURM -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|---------------------------|------------------------|---------|---------|---------------| -| slurm_version | number | SLURM版本 | 版本必须为1 | RFC 8416 §3.2 | -| validation_output_filters | ValidationOutputFilter | 过滤条件 | | | -| locally_added_assertions | LocallyAddedAssertions | 本地添加断言 | | | - -### ValidationOutputFilter -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|----------------|-------------------|-----------|---------|---------------| -| prefix_filters | Vec | 前缀过滤 | 可以为空数组 | RFC 8416 §3.3 | -| bgpsec_filters | Vec | BGPsec过滤 | 可以为空数组 | RFC 8416 §3.3 | - -### LocallyAddedAssertions -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|-------------------|----------------------|-----------|---------|---------------| -| prefix_assertions | Vec | 前缀断言 | 可以为空数组 | RFC 8416 §3.4 | -| bgpsec_assertions | Vec | BGPsec断言 | 可以为空数组 | RFC 8416 §3.4 | - -### PrefixFilter -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|---------|--------|------|--------------------------------|-----------------| -| prefix | string | 前缀 | IPv4前缀或IPv6前缀,prefix和asn至少存在一个 | RFC 8416 §3.3.1 | -| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 | -| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 | - -### BgpsecFilter -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|---------|--------|------|------------------|------------------| -| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 | -| ski | u8 | | 证书的SKI | RFC 8416 §3.3.1 | -| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 | - -### PrefixAssertion -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|-------------------|--------|--------|---------------|-----------------| -| prefix | string | 前缀 | IPv4前缀或IPv6前缀 | RFC 8416 §3.4.1 | -| asn | number | ASN | | RFC 8416 §3.4.1 | -| max_prefix_length | number | 最大前缀长度 | 可选字段 | RFC 8416 §3.4.1 | -| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.1 | - - -### BgpsecAssertion -| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 | -|-------------------|--------|--------|------------------|-----------------| -| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.4.2 | -| ski | u8 | | 证书的SKI | RFC 8416 §3.4.2 | -| router_public_key | u8 | 证书的SKI | | RFC 8416 §3.4.2 | -| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.2 | - -> 注:BGPsec部分可以在第一版考虑先留空 - -## 10.4 规则 +## 10.7 参考文献 +- RFC 8416: https://www.rfc-editor.org/rfc/rfc8416.html +- draft-ietf-sidrops-aspa-slurm-04: https://www.ietf.org/archive/id/draft-ietf-sidrops-aspa-slurm-04.html diff --git a/specs/11_rtr.md b/specs/11_rtr.md index 8a6051a..9bb3eff 100644 --- a/specs/11_rtr.md +++ b/specs/11_rtr.md @@ -1,65 +1,158 @@ -# 11. 
RTR (The Resource Public Key Infrastructure (RPKI) to Router Protocol) +# 11. RTR(RPKI to Router Protocol) -## 11.1 Cache Server +## 11.1 目标与文档范围 -### 11.1.1 功能需求 +RTR 用于把 RP/Cache 已完成密码学验证的 RPKI 数据下发给路由器。 -- 支持Full Sync(Reset Query) -- 支持Incremental Sync(Serial Query) -- 支持多客户端并发 -- 支持Serial递增 -- 保留一定数量的delta -- 支持原子更新 +本文按以下规范整理: +- RFC 6810(RTR v0) +- RFC 8210(RTR v1,更新 RFC 6810) +- draft-ietf-sidrops-8210bis-25(RTR v2,草案) -### 11.1.2 架构设计 -采用一级缓存+二级缓存并存的方式。 +## 11.2 协议演进 -![img.png](img/img.png) +### 11.2.1 RFC 6810(v0) -其中,一级缓存为运行时缓存,主要职责: -- 存储当前完整的snapshot -- 历史Delta队列管理 -- Serial管理 -- RTR查询响应 +- 只定义 Prefix Origin 相关 payload(IPv4/IPv6 Prefix PDU)。 +- 主要 PDU:Serial Notify / Serial Query / Reset Query / Cache Response / Prefix / End of Data / Cache Reset / Error Report。 -二级缓存为持久化缓存,主要职责: -- snapshot持久化 -- 缓存重启后的快速恢复(snapshot和serial) -- 不参与实时查询 -- 异步写入 +### 11.2.2 RFC 8210(v1) -### 11.1.3 核心数据结构设计 +在 v0 基础上新增/强化: +- 新增 `Router Key PDU`(PDU Type 9,v1 可用,v0 保留)。 +- 强化协议版本协商与降级行为。 +- `End of Data` 在 v1 中携带 `Refresh/Retry/Expire` 三个计时参数。 -#### 11.1.3.1 总cache -```rust -struct RtrCache { - serial: AtomicU32, - snapshot: ArcSwap, - deltas: RwLock>>, - max_delta: usize, -} -``` +### 11.2.3 Version 2(草案) -#### 11.1.3.2 Snapshot -```rust -struct Snapshot { - origins: Vec, - router_keys: Vec, - aspas: Vec, - created_at: Instant, -} -``` +在 v1 基础上新增/强化: +- 新增 `ASPA PDU`(PDU Type 11,仅 v2)。 +- 新增 “Races, Ordering, and Transactions” 章节,要求缓存按规定顺序输出 payload 以降低路由器短暂误判。 +- 协议版本提升到 `2`。 +- 明确 PDU 最大长度上限为 64k(65535)。 -#### 11.1.3.3 Delta -```rust -struct Delta { - serial: u32, - announced: Vec, - withdrawn: Vec, -} -``` +## 11.3 PDU 与版本矩阵 +PDU 类型(按规范注册表): -## 11.2 Transport +| PDU Type | 名称 | v0 (RFC6810) | v1 (RFC8210) | v2 (8210bis-25) | +|---|---|---|---|---| +| 0 | Serial Notify | 支持 | 支持 | 支持 | +| 1 | Serial Query | 支持 | 支持 | 支持 | +| 2 | Reset Query | 支持 | 支持 | 支持 | +| 3 | Cache Response | 支持 | 支持 | 支持 | +| 4 | IPv4 Prefix | 支持 | 支持 | 支持 | +| 6 | IPv6 Prefix | 支持 | 支持 | 支持 | +| 7 | End of Data | 支持 | 支持(含计时参数) | 支持 | +| 8 | Cache Reset | 支持 | 支持 | 支持 | +| 9 | Router Key | 保留 | 支持 | 支持 | +| 10 | Error Report | 支持 | 支持 | 支持 | +| 11 | ASPA | 保留 | 保留 | 支持 | -初版实现RTR over TLS(可外网)和RTR over TCP(内网)两种方式。 \ No newline at end of file +通用字段约束: +- `Protocol Version`:8-bit。 +- `PDU Type`:8-bit。 +- `Session ID`:16-bit。 +- `Length`:32-bit。 +- 保留位(zero/reserved)发送必须为 0,接收时按规范处理。 + +## 11.4 关键 PDU 语义 + +### 11.4.1 Serial Notify(Type 0) + +- 由 Cache 主动发送,提示有新序列可拉取。 +- 是少数可不由 Router 请求触发的消息。 + +### 11.4.2 Reset Query(Type 2)与 Cache Response(Type 3) + +- Router 启动或失配时发 `Reset Query` 请求全量。 +- Cache 回复 `Cache Response`,随后发送全量 payload,最后 `End of Data`。 + +### 11.4.3 Serial Query(Type 1) + +- Router 持有上次 `Session ID + Serial` 时请求增量。 +- Cache 若可提供增量:返回变化集。 +- Cache 若无法从该 serial 补增量:返回 `Cache Reset`,要求 Router 走全量。 + +### 11.4.4 Prefix / Router Key / ASPA payload + +- `IPv4 Prefix`(Type 4)/ `IPv6 Prefix`(Type 6):表示 VRP 的 announce/withdraw。 +- `Router Key`(Type 9,v1+):表示 BGPsec Router Key 的 announce/withdraw。 +- `ASPA`(Type 11,v2 草案):表示 ASPA 数据单元的 announce/withdraw。 + +语义要点(v1 / v2 草案): +- 对同一 payload 键(如 Prefix 四元组、Router Key 三元组、ASPA customer 键)应维护清晰的替换/撤销关系。 +- Cache 负责把历史变化“合并简化”后再发给 Router,避免无意义抖动。 + +### 11.4.5 End of Data(Type 7) + +- 标识一次响应结束,并给出当前 serial。 +- v0:不含定时器字段。 +- v1/v2:携带 `Refresh Interval`、`Retry Interval`、`Expire Interval`。 + +## 11.5 协议时序 + +### 11.5.1 初始同步(Full Sync) + +1. Router 建连后发 `Reset Query`(带支持的协议版本)。 +2. Cache 回 `Cache Response`。 +3. Cache 按规范发送 payload 集合。 +4. 
Cache 发 `End of Data` 收尾。 + +### 11.5.2 增量同步(Incremental Sync) + +1. Router 发 `Serial Query(session_id, serial)`。 +2. Cache 若可增量,返回变化并以 `End of Data` 收尾。 +3. 若不可增量,返回 `Cache Reset`;Router 退回 Full Sync。 + +## 11.6 版本协商与降级 + +- Router 每次新连接必须由 `Reset Query` 或 `Serial Query` 启动,携带其协议版本。 +- 双方在协商完成后,本连接内版本固定。 +- 遇到不支持版本时,可按规范降级(例如 v1 对 v0、v2 对 v1/v0)或返回 `Unsupported Protocol Version` 后断开。 +- 协商期若收到 `Serial Notify`,Router 应按规范兼容处理(通常忽略,待协商完成)。 + +## 11.7 计时器与失效(v1/v2) + +`End of Data` 下发三个参数: +- `Refresh Interval`:多久后主动刷新。 +- `Retry Interval`:失败后重试间隔。 +- `Expire Interval`:本地数据最长可保留时长。 + +规范边界(RFC 8210): +- Refresh: 1 .. 86400(推荐 3600) +- Retry: 1 .. 7200(推荐 600) +- Expire: 600 .. 172800(推荐 7200) +- 且 `Expire` 必须大于 `Refresh` 和 `Retry`。 + +## 11.8 Version 2(草案)新增关注点 + +### 11.8.1 ASPA PDU + +- 新增 ASPA 传输能力(Type 11)。 +- 针对同一 customer ASN,Cache 需向 Router 提供一致且可替换的 ASPA 视图。 + +### 11.8.2 排序与事务 + +- 草案新增 race 条件说明(如前缀替换、撤销先后导致短暂误判)。 +- 对 Cache 输出 payload 的顺序提出约束。 +- 建议 Router 使用“事务式应用”(例如接收到完整响应后再切换生效)降低中间态影响。 + +## 11.9 传输与安全 + +规范定义可承载于多种传输: +- SSH +- TLS +- TCP MD5 +- TCP-AO + +安全原则: +- Router 与 Cache 之间必须建立可信关系。 +- 需要完整性/机密性时优先使用具备认证与加密能力的传输。 +- 若使用普通 TCP,部署上应限制在可信受控网络中。 + +## 11.10 参考文献 + +- RFC 6810: https://www.rfc-editor.org/rfc/rfc6810.html +- RFC 8210: https://www.rfc-editor.org/rfc/rfc8210.html +- draft-ietf-sidrops-8210bis-25: https://www.ietf.org/archive/id/draft-ietf-sidrops-8210bis-25.html diff --git a/src/bin/rtr_debug_client/main.rs b/src/bin/rtr_debug_client/main.rs index d7f03ce..dffc315 100644 --- a/src/bin/rtr_debug_client/main.rs +++ b/src/bin/rtr_debug_client/main.rs @@ -7,18 +7,16 @@ use rustls::{ClientConfig as RustlsClientConfig, RootCertStore}; use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf}; use tokio::net::TcpStream; -use tokio::time::{timeout, Duration, Instant}; +use tokio::time::{Duration, Instant, timeout}; use tokio_rustls::TlsConnector; -mod wire; mod pretty; mod protocol; +mod wire; -use crate::wire::{read_pdu, send_reset_query, send_serial_query}; -use crate::pretty::{ - parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu, -}; +use crate::pretty::{parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu}; use crate::protocol::{PduHeader, PduType, QueryMode}; +use crate::wire::{read_pdu, send_reset_query, send_serial_query}; const DEFAULT_READ_TIMEOUT_SECS: u64 = 30; const DEFAULT_POLL_INTERVAL_SECS: u64 = 600; @@ -38,7 +36,10 @@ async fn main() -> io::Result<()> { println!("transport: {}", config.transport.describe()); println!("version : {}", config.version); println!("timeout : {}s", config.read_timeout_secs); - println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs); + println!( + "poll : {}s (default before EndOfData refresh is known)", + config.default_poll_secs + ); println!("keep-after-error: {}", config.keep_after_error); match &config.mode { QueryMode::Reset => { @@ -72,11 +73,7 @@ async fn main() -> io::Result<()> { } Err(err) => { let delay = state.reconnect_delay_secs(); - eprintln!( - "connect failed: {}. retry after {}s", - err, - delay - ); + eprintln!("connect failed: {}. 
retry after {}s", err, delay); tokio::time::sleep(Duration::from_secs(delay)).await; } } @@ -171,10 +168,7 @@ async fn main() -> io::Result<()> { if reconnect { let delay = state.reconnect_delay_secs(); state.current_session_id = None; - println!( - "[reconnect] transport disconnected, retry after {}s", - delay - ); + println!("[reconnect] transport disconnected, retry after {}s", delay); tokio::time::sleep(Duration::from_secs(delay)).await; } } @@ -189,8 +183,7 @@ async fn send_resume_query( (Some(session_id), Some(serial)) => { println!( "reconnected, send Serial Query with session_id={}, serial={}", - session_id, - serial + session_id, serial ); send_serial_query(writer, state.version, session_id, serial).await?; } @@ -294,26 +287,20 @@ async fn handle_incoming_pdu( println!(); println!( "[notify] received Serial Notify: session_id={}, notify_serial={:?}", - notify_session_id, - notify_serial + notify_session_id, notify_serial ); match (state.session_id, state.serial, notify_serial) { (Some(current_session_id), Some(current_serial), Some(_new_serial)) - if current_session_id == notify_session_id => - { - println!( - "received Serial Notify for current session {}, send Serial Query with serial {}", - current_session_id, current_serial - ); - send_serial_query( - writer, - state.version, - current_session_id, - current_serial, - ) - .await?; - } + if current_session_id == notify_session_id => + { + println!( + "received Serial Notify for current session {}, send Serial Query with serial {}", + current_session_id, current_serial + ); + send_serial_query(writer, state.version, current_session_id, current_serial) + .await?; + } _ => { println!( @@ -366,10 +353,7 @@ async fn handle_incoming_pdu( Ok(()) } -async fn handle_poll_tick( - writer: &mut ClientWriter, - state: &mut ClientState, -) -> io::Result<()> { +async fn handle_poll_tick(writer: &mut ClientWriter, state: &mut ClientState) -> io::Result<()> { println!(); println!( "[auto-poll] timer fired (interval={}s)", @@ -422,23 +406,21 @@ async fn handle_console_command( state.schedule_next_poll(); } - ["serial"] => { - match (state.session_id, state.serial) { - (Some(session_id), Some(serial)) => { - println!( - "manual command: send Serial Query with current state: session_id={}, serial={}", - session_id, serial - ); - send_serial_query(writer, state.version, session_id, serial).await?; - state.schedule_next_poll(); - } - _ => { - println!( - "manual command failed: current session_id/serial not available, use `reset` or `serial `" - ); - } + ["serial"] => match (state.session_id, state.serial) { + (Some(session_id), Some(serial)) => { + println!( + "manual command: send Serial Query with current state: session_id={}, serial={}", + session_id, serial + ); + send_serial_query(writer, state.version, session_id, serial).await?; + state.schedule_next_poll(); } - } + _ => { + println!( + "manual command failed: current session_id/serial not available, use `reset` or `serial `" + ); + } + }, ["serial", session_id, serial] => { let session_id = match session_id.parse::() { @@ -493,7 +475,10 @@ async fn handle_console_command( "current effective poll interval: {}s", state.effective_poll_secs() ); - println!("poll interval source : {}", state.poll_interval_source()); + println!( + "poll interval source : {}", + state.poll_interval_source() + ); println!("stored refresh hint : {:?}", state.refresh); println!("default poll interval : {}s", state.default_poll_secs); println!("last_error_code : {:?}", state.last_error_code); @@ -626,17 +611,20 @@ impl 
ClientState { fn effective_poll_secs(&self) -> u64 { if self.should_prefer_retry_poll() { - self.retry - .map(|v| v as u64) - .unwrap_or_else(|| self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs)) + self.retry.map(|v| v as u64).unwrap_or_else(|| { + self.refresh + .map(|v| v as u64) + .unwrap_or(self.default_poll_secs) + }) } else { - self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs) + self.refresh + .map(|v| v as u64) + .unwrap_or(self.default_poll_secs) } } fn schedule_next_poll(&mut self) { - self.next_poll_deadline = - Instant::now() + Duration::from_secs(self.effective_poll_secs()); + self.next_poll_deadline = Instant::now() + Duration::from_secs(self.effective_poll_secs()); } fn pause_auto_poll(&mut self) { @@ -728,7 +716,10 @@ impl Config { } "--server-name" => { let name = args.next().ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidInput, "--server-name requires a value") + io::Error::new( + io::ErrorKind::InvalidInput, + "--server-name requires a value", + ) })?; ensure_tls_config(&mut transport)?.server_name = Some(name); } @@ -805,10 +796,7 @@ impl Config { let serial = positional .next() .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "serial mode requires serial", - ) + io::Error::new(io::ErrorKind::InvalidInput, "serial mode requires serial") })? .parse::() .map_err(|e| { @@ -949,9 +937,15 @@ async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result io::Result { if added == 0 { return Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("no valid CA certificates found in {}", ca_cert_path.display()), + format!( + "no valid CA certificates found in {}", + ca_cert_path.display() + ), )); } diff --git a/src/bin/rtr_debug_client/pretty.rs b/src/bin/rtr_debug_client/pretty.rs index 401caab..6610721 100644 --- a/src/bin/rtr_debug_client/pretty.rs +++ b/src/bin/rtr_debug_client/pretty.rs @@ -1,9 +1,8 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use crate::protocol::{ - flag_meaning, hex_bytes, PduHeader, PduType, ASPA_FIXED_BODY_LEN, - END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN, - IPV6_PREFIX_BODY_LEN, ROUTER_KEY_FIXED_BODY_LEN, + ASPA_FIXED_BODY_LEN, END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN, + IPV6_PREFIX_BODY_LEN, PduHeader, PduType, ROUTER_KEY_FIXED_BODY_LEN, flag_meaning, hex_bytes, }; pub fn print_pdu(header: &PduHeader, body: &[u8]) { @@ -143,8 +142,7 @@ fn print_error_report(header: &PduHeader, body: &[u8]) { return; } - let encapsulated_len = - u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize; + let encapsulated_len = u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize; if body.len() < 4 + encapsulated_len + 4 { println!("invalid ErrorReport: truncated encapsulated PDU"); diff --git a/src/bin/rtr_debug_client/wire.rs b/src/bin/rtr_debug_client/wire.rs index 346cb01..48a6766 100644 --- a/src/bin/rtr_debug_client/wire.rs +++ b/src/bin/rtr_debug_client/wire.rs @@ -2,9 +2,7 @@ use std::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use crate::protocol::{ - PduHeader, PduType, RawPdu, HEADER_LEN, MAX_PDU_LEN, SERIAL_QUERY_LEN, -}; +use crate::protocol::{HEADER_LEN, MAX_PDU_LEN, PduHeader, PduType, RawPdu, SERIAL_QUERY_LEN}; pub async fn send_reset_query(stream: &mut S, version: u8) -> io::Result<()> where @@ -56,10 +54,7 @@ where if header.length < HEADER_LEN as u32 { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!( - "invalid PDU length {} < {}", - header.length, HEADER_LEN - ), + 
format!("invalid PDU length {} < {}", header.length, HEADER_LEN), )); } @@ -78,4 +73,4 @@ where stream.read_exact(&mut body).await?; Ok(RawPdu { header, body }) -} \ No newline at end of file +} diff --git a/src/bin/slurm_apply_client.rs b/src/bin/slurm_apply_client.rs new file mode 100644 index 0000000..eabc9c6 --- /dev/null +++ b/src/bin/slurm_apply_client.rs @@ -0,0 +1,168 @@ +use std::env; +use std::fs::File; +use std::path::PathBuf; + +use anyhow::{Context, Result, anyhow}; +use serde::Serialize; + +use rpki::rtr::payload::Payload; +use rpki::slurm::file::SlurmFile; +use rpki::source::ccr::{ + find_latest_ccr_file, load_ccr_snapshot_from_file, snapshot_to_payloads_with_options, +}; + +#[derive(Debug)] +struct Cli { + ccr_path: PathBuf, + slurm_path: PathBuf, + strict_ccr: bool, + dump_payloads: bool, +} + +#[derive(Debug, Serialize)] +struct Output { + ccr_path: String, + slurm_path: String, + produced_at: Option, + slurm_version: u32, + input_payload_count: usize, + input_vrp_count: usize, + input_vap_count: usize, + output_payload_count: usize, + output_vrp_count: usize, + output_vap_count: usize, + invalid_vrps: Vec, + invalid_vaps: Vec, + sample_output_aspa_customers: Vec, + payloads: Option>, +} + +fn main() -> Result<()> { + let cli = parse_args(env::args().skip(1))?; + + let snapshot = load_ccr_snapshot_from_file(&cli.ccr_path) + .with_context(|| format!("failed to load CCR snapshot: {}", cli.ccr_path.display()))?; + let slurm = load_slurm(&cli.slurm_path)?; + + let conversion = snapshot_to_payloads_with_options(&snapshot, cli.strict_ccr)?; + let payloads = slurm.apply(&conversion.payloads); + let (input_vrp_count, input_vap_count) = count_vrps_and_vaps(&conversion.payloads); + let (output_vrp_count, output_vap_count) = count_vrps_and_vaps(&payloads); + + let output = Output { + ccr_path: cli.ccr_path.display().to_string(), + slurm_path: cli.slurm_path.display().to_string(), + produced_at: snapshot.produced_at.clone(), + slurm_version: slurm.version().as_u32(), + input_payload_count: conversion.payloads.len(), + input_vrp_count, + input_vap_count, + output_payload_count: payloads.len(), + output_vrp_count, + output_vap_count, + invalid_vrps: conversion.invalid_vrps, + invalid_vaps: conversion.invalid_vaps, + sample_output_aspa_customers: sample_aspa_customers(&payloads, 8), + payloads: cli.dump_payloads.then_some(payloads), + }; + + println!("{}", serde_json::to_string_pretty(&output)?); + Ok(()) +} + +fn load_slurm(path: &PathBuf) -> Result { + let file = File::open(path) + .with_context(|| format!("failed to open SLURM file: {}", path.display()))?; + SlurmFile::from_reader(file) + .with_context(|| format!("failed to parse SLURM file: {}", path.display())) +} + +fn parse_args(args: impl Iterator) -> Result { + let mut strict_ccr = false; + let mut dump_payloads = false; + let mut positionals = Vec::new(); + + for arg in args { + match arg.as_str() { + "--strict-ccr" => strict_ccr = true, + "--dump-payloads" => dump_payloads = true, + "-h" | "--help" => { + print_help(); + std::process::exit(0); + } + _ if arg.starts_with('-') => { + return Err(anyhow!("unknown option: {}", arg)); + } + _ => positionals.push(arg), + } + } + + if positionals.is_empty() { + return Ok(Cli { + ccr_path: find_latest_ccr_file("data") + .context("failed to find latest .ccr in ./data for default run")?, + slurm_path: PathBuf::from("data/example.slurm"), + strict_ccr, + dump_payloads, + }); + } + + if positionals.len() != 2 { + print_help(); + return Err(anyhow!( + "expected: slurm_apply_client " + )); 
+ } + + Ok(Cli { + ccr_path: PathBuf::from(&positionals[0]), + slurm_path: PathBuf::from(&positionals[1]), + strict_ccr, + dump_payloads, + }) +} + +fn print_help() { + eprintln!( + "Usage: cargo run --bin slurm_apply_client -- [--strict-ccr] [--dump-payloads] " + ); + eprintln!(); + eprintln!("Reads a CCR snapshot, converts it into payloads, applies SLURM, and prints JSON."); + eprintln!( + "When no arguments are provided, it defaults to the latest .ccr under ./data and ./data/example.slurm." + ); + eprintln!("Use --dump-payloads to include the full payload list in the JSON output."); +} + +fn count_vrps_and_vaps(payloads: &[Payload]) -> (usize, usize) { + let mut vrps = 0; + let mut vaps = 0; + + for payload in payloads { + match payload { + Payload::RouteOrigin(_) => vrps += 1, + Payload::Aspa(_) => vaps += 1, + Payload::RouterKey(_) => {} + } + } + + (vrps, vaps) +} + +fn sample_aspa_customers(payloads: &[Payload], limit: usize) -> Vec { + let mut customers = Vec::new(); + + for payload in payloads { + if let Payload::Aspa(aspa) = payload { + let customer = aspa.customer_asn().into_u32(); + if !customers.contains(&customer) { + customers.push(customer); + if customers.len() == limit { + break; + } + } + } + } + + customers +} diff --git a/src/lib.rs b/src/lib.rs index 5ef2acb..88e8e53 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod data_model; -mod slurm; +pub mod slurm; pub mod rtr; +pub mod source; diff --git a/src/main.rs b/src/main.rs index 19c8073..41af49e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,15 +3,15 @@ use std::net::SocketAddr; use std::sync::{Arc, RwLock}; use std::time::Duration; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use tokio::task::JoinHandle; use tracing::{info, warn}; -use rpki::rtr::ccr::{find_latest_ccr_file, load_ccr_payloads_from_file_with_options, load_ccr_snapshot_from_file}; use rpki::rtr::cache::{RtrCache, SharedRtrCache}; use rpki::rtr::payload::Timing; use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService}; use rpki::rtr::store::RtrStore; +use rpki::source::pipeline::{PayloadLoadConfig, load_payloads_from_latest_sources}; #[derive(Debug, Clone)] struct AppConfig { @@ -21,6 +21,7 @@ struct AppConfig { db_path: String, ccr_dir: String, + slurm_dir: Option, tls_cert_path: String, tls_key_path: String, tls_client_ca_path: String, @@ -42,6 +43,7 @@ impl Default for AppConfig { db_path: "./rtr-db".to_string(), ccr_dir: "./data".to_string(), + slurm_dir: None, tls_cert_path: "./certs/server.crt".to_string(), tls_key_path: "./certs/server.key".to_string(), tls_client_ca_path: "./certs/client-ca.crt".to_string(), @@ -85,6 +87,14 @@ impl AppConfig { if let Some(value) = env_var("RPKI_RTR_CCR_DIR")? { config.ccr_dir = value; } + if let Some(value) = env_var("RPKI_RTR_SLURM_DIR")? { + let value = value.trim(); + config.slurm_dir = if value.is_empty() { + None + } else { + Some(value.to_string()) + }; + } if let Some(value) = env_var("RPKI_RTR_TLS_CERT_PATH")? { config.tls_cert_path = value; } @@ -104,8 +114,7 @@ impl AppConfig { parse_bool(&value, "RPKI_RTR_PRUNE_DELTA_BY_SNAPSHOT_SIZE")?; } if let Some(value) = env_var("RPKI_RTR_STRICT_CCR_VALIDATION")? { - config.strict_ccr_validation = - parse_bool(&value, "RPKI_RTR_STRICT_CCR_VALIDATION")?; + config.strict_ccr_validation = parse_bool(&value, "RPKI_RTR_STRICT_CCR_VALIDATION")?; } if let Some(value) = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")? 
{ let secs: u64 = value.parse().map_err(|err| { @@ -118,9 +127,9 @@ impl AppConfig { config.refresh_interval = Duration::from_secs(secs); } if let Some(value) = env_var("RPKI_RTR_MAX_CONNECTIONS")? { - config.service_config.max_connections = value.parse().map_err(|err| { - anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err) - })?; + config.service_config.max_connections = value + .parse() + .map_err(|err| anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err))?; } if let Some(value) = env_var("RPKI_RTR_NOTIFY_QUEUE_SIZE")? { config.service_config.notify_queue_size = value.parse().map_err(|err| { @@ -184,12 +193,17 @@ fn open_store(config: &AppConfig) -> Result { } fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result { + let payload_load_config = PayloadLoadConfig { + ccr_dir: config.ccr_dir.clone(), + slurm_dir: config.slurm_dir.clone(), + strict_ccr_validation: config.strict_ccr_validation, + }; let initial_cache = RtrCache::default().init( store, config.max_delta, config.prune_delta_by_snapshot_size, Timing::default(), - || load_payloads_from_latest_ccr(&config.ccr_dir, config.strict_ccr_validation), + || load_payloads_from_latest_sources(&payload_load_config), )?; let shared_cache: SharedRtrCache = Arc::new(RwLock::new(initial_cache)); @@ -232,8 +246,11 @@ fn spawn_refresh_task( notifier: RtrNotifier, ) -> JoinHandle<()> { let refresh_interval = config.refresh_interval; - let ccr_dir = config.ccr_dir.clone(); - let strict_ccr_validation = config.strict_ccr_validation; + let payload_load_config = PayloadLoadConfig { + ccr_dir: config.ccr_dir.clone(), + slurm_dir: config.slurm_dir.clone(), + strict_ccr_validation: config.strict_ccr_validation, + }; tokio::spawn(async move { let mut interval = tokio::time::interval(refresh_interval); @@ -241,7 +258,7 @@ fn spawn_refresh_task( loop { interval.tick().await; - match load_payloads_from_latest_ccr(&ccr_dir, strict_ccr_validation) { + match load_payloads_from_latest_sources(&payload_load_config) { Ok(payloads) => { let payload_count = payloads.len(); let updated = { @@ -261,7 +278,7 @@ fn spawn_refresh_task( if new_serial != old_serial { info!( "RTR cache refresh applied: ccr_dir={}, payload_count={}, old_serial={}, new_serial={}", - ccr_dir, + payload_load_config.ccr_dir, payload_count, old_serial, new_serial @@ -270,9 +287,7 @@ fn spawn_refresh_task( } else { info!( "RTR cache refresh found no change: ccr_dir={}, payload_count={}, serial={}", - ccr_dir, - payload_count, - old_serial + payload_load_config.ccr_dir, payload_count, old_serial ); false } @@ -290,7 +305,10 @@ fn spawn_refresh_task( } } Err(err) => { - warn!("failed to reload CCR payloads from {}: {:?}", ccr_dir, err); + warn!( + "failed to reload CCR/SLURM payloads from {}: {:?}", + payload_load_config.ccr_dir, err + ); } } } @@ -317,16 +335,17 @@ fn log_startup_config(config: &AppConfig) { } info!("ccr_dir={}", config.ccr_dir); + info!( + "slurm_dir={}", + config.slurm_dir.as_deref().unwrap_or("disabled") + ); info!("max_delta={}", config.max_delta); info!("strict_ccr_validation={}", config.strict_ccr_validation); info!( "refresh_interval_secs={}", config.refresh_interval.as_secs() ); - info!( - "max_connections={}", - config.service_config.max_connections - ); + info!("max_connections={}", config.service_config.max_connections); info!( "notify_queue_size={}", config.service_config.notify_queue_size @@ -372,50 +391,3 @@ fn parse_bool(value: &str, name: &str) -> Result { _ => Err(anyhow!("invalid {} '{}': expected boolean", name, value)), } } - 
-fn load_payloads_from_latest_ccr( - ccr_dir: &str, - strict_ccr_validation: bool, -) -> Result> { - let latest = find_latest_ccr_file(ccr_dir)?; - let snapshot = load_ccr_snapshot_from_file(&latest)?; - let vrp_count = snapshot.vrps.len(); - let vap_count = snapshot.vaps.len(); - let produced_at = snapshot.produced_at.clone(); - let conversion = load_ccr_payloads_from_file_with_options(&latest, strict_ccr_validation)?; - let payloads = conversion.payloads; - - if !conversion.invalid_vrps.is_empty() { - warn!( - "CCR load skipped invalid VRPs: file={}, skipped={}, samples={:?}", - latest.display(), - conversion.invalid_vrps.len(), - sample_messages(&conversion.invalid_vrps) - ); - } - - if !conversion.invalid_vaps.is_empty() { - warn!( - "CCR load skipped invalid VAPs/ASPAs: file={}, skipped={}, samples={:?}", - latest.display(), - conversion.invalid_vaps.len(), - sample_messages(&conversion.invalid_vaps) - ); - } - - info!( - "loaded latest CCR snapshot: file={}, produced_at={:?}, vrp_count={}, vap_count={}, payload_count={}, strict_ccr_validation={}", - latest.display(), - produced_at, - vrp_count, - vap_count, - payloads.len(), - strict_ccr_validation - ); - - Ok(payloads) -} - -fn sample_messages(messages: &[String]) -> Vec<&str> { - messages.iter().take(3).map(String::as_str).collect() -} diff --git a/src/rtr/cache/core.rs b/src/rtr/cache/core.rs index 817c0c0..1059453 100644 --- a/src/rtr/cache/core.rs +++ b/src/rtr/cache/core.rs @@ -1,14 +1,14 @@ -use std::collections::{BTreeMap, VecDeque}; -use std::cmp::Ordering; -use std::sync::Arc; use anyhow::Result; use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BTreeMap, VecDeque}; +use std::sync::Arc; use tracing::{debug, info, warn}; use crate::rtr::payload::{Payload, Timing}; use super::model::{Delta, DualTime, Snapshot}; -use super::ordering::{change_key, ChangeKey}; +use super::ordering::{ChangeKey, change_key}; const SERIAL_HALF_RANGE: u32 = 1 << 31; @@ -166,9 +166,7 @@ impl RtrCacheBuilder { let serial = self.serial.unwrap_or(0); let created_at = self.created_at.unwrap_or_else(|| now.clone()); let availability = self.availability.unwrap_or(CacheAvailability::Ready); - let session_ids = self - .session_ids - .unwrap_or_else(SessionIds::random_distinct); + let session_ids = self.session_ids.unwrap_or_else(SessionIds::random_distinct); RtrCache { availability, @@ -235,8 +233,7 @@ impl RtrCache { self.serial = self.serial.wrapping_add(1); debug!( "RTR cache advanced serial: old_serial={}, new_serial={}", - old, - self.serial + old, self.serial ); self.serial } @@ -251,9 +248,7 @@ impl RtrCache { let snapshot_wire_size = estimate_snapshot_payload_wire_size(&self.snapshot); let mut cumulative_delta_wire_size = estimate_delta_window_payload_wire_size(&self.deltas); - while !self.deltas.is_empty() - && cumulative_delta_wire_size >= snapshot_wire_size - { + while !self.deltas.is_empty() && cumulative_delta_wire_size >= snapshot_wire_size { if let Some(oldest) = self.deltas.pop_front() { dropped_serials.push(oldest.serial()); cumulative_delta_wire_size = @@ -262,9 +257,7 @@ impl RtrCache { } debug!( "RTR cache delta-size pruning evaluated: snapshot_wire_size={}, cumulative_delta_wire_size={}, dropped_serials={:?}", - snapshot_wire_size, - cumulative_delta_wire_size, - dropped_serials + snapshot_wire_size, cumulative_delta_wire_size, dropped_serials ); } debug!( @@ -292,7 +285,10 @@ impl RtrCache { } } - pub(super) fn apply_update(&mut self, new_payloads: Vec) -> Result> { + pub(super) fn apply_update( + 
&mut self, + new_payloads: Vec, + ) -> Result> { self.last_update_begin = DualTime::now(); info!( "RTR cache applying update: availability={:?}, current_serial={}, incoming_payloads={}", @@ -319,14 +315,15 @@ impl RtrCache { self.last_update_end = DualTime::now(); if !changed { - debug!("RTR cache update produced empty snapshot but cache was already unavailable; no state change"); + debug!( + "RTR cache update produced empty snapshot but cache was already unavailable; no state change" + ); return Ok(None); } info!( "RTR cache update cleared usable data and marked cache unavailable: serial={}, session_ids={:?}", - self.serial, - self.session_ids + self.serial, self.session_ids ); return Ok(Some(AppliedUpdate { @@ -349,8 +346,7 @@ impl RtrCache { self.last_update_end = DualTime::now(); debug!( "RTR cache update detected identical snapshot content: serial={}, session_ids={:?}", - self.serial, - self.session_ids + self.serial, self.session_ids ); return Ok(None); } @@ -455,8 +451,7 @@ impl RtrCache { if client_serial == self.serial { debug!( "RTR cache delta query is already up to date: client_serial={}, cache_serial={}", - client_serial, - self.serial + client_serial, self.serial ); return SerialResult::UpToDate; } @@ -467,8 +462,7 @@ impl RtrCache { ) { warn!( "RTR cache delta query requires reset due to invalid/newer client serial: client_serial={}, cache_serial={}", - client_serial, - self.serial + client_serial, self.serial ); return SerialResult::ResetRequired; } @@ -489,8 +483,7 @@ impl RtrCache { if deltas.is_empty() { debug!( "RTR cache delta query resolved to no deltas: client_serial={}, cache_serial={}", - client_serial, - self.serial + client_serial, self.serial ); return SerialResult::UpToDate; } @@ -633,7 +626,11 @@ fn estimate_payload_wire_size(payload: &Payload, announce: bool) -> usize { }, Payload::RouterKey(key) => 8 + 20 + 4 + key.spki().len(), Payload::Aspa(aspa) => { - let providers = if announce { aspa.provider_asns().len() } else { 0 }; + let providers = if announce { + aspa.provider_asns().len() + } else { + 0 + }; 8 + 4 + providers * 4 } } diff --git a/src/rtr/cache/model.rs b/src/rtr/cache/model.rs index 6f9c41b..5cee060 100644 --- a/src/rtr/cache/model.rs +++ b/src/rtr/cache/model.rs @@ -195,7 +195,12 @@ impl Snapshot { } if !self.same_aspas(new_snapshot) { - diff_aspas(&self.aspas, &new_snapshot.aspas, &mut announced, &mut withdrawn); + diff_aspas( + &self.aspas, + &new_snapshot.aspas, + &mut announced, + &mut withdrawn, + ); } (announced, withdrawn) @@ -206,9 +211,8 @@ impl Snapshot { } pub fn payloads(&self) -> Vec { - let mut v = Vec::with_capacity( - self.origins.len() + self.router_keys.len() + self.aspas.len(), - ); + let mut v = + Vec::with_capacity(self.origins.len() + self.router_keys.len() + self.aspas.len()); v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin)); v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey)); @@ -268,9 +272,7 @@ impl Snapshot { } pub fn is_empty(&self) -> bool { - self.origins.is_empty() - && self.router_keys.is_empty() - && self.aspas.is_empty() + self.origins.is_empty() && self.router_keys.is_empty() && self.aspas.is_empty() } } diff --git a/src/rtr/cache/ordering.rs b/src/rtr/cache/ordering.rs index 75b7392..08d2879 100644 --- a/src/rtr/cache/ordering.rs +++ b/src/rtr/cache/ordering.rs @@ -64,8 +64,18 @@ enum PayloadPduType { #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] pub(crate) enum RouteOriginKey { - V4 { addr: u32, plen: u8, mlen: u8, asn: u32 }, - V6 { addr: u128, plen: u8, mlen: 
u8, asn: u32 }, + V4 { + addr: u32, + plen: u8, + mlen: u8, + asn: u32, + }, + V6 { + addr: u128, + plen: u8, + mlen: u8, + asn: u32, + }, } pub(crate) fn change_key(payload: &Payload) -> ChangeKey { @@ -287,7 +297,11 @@ fn payload_brief(payload: &Payload) -> String { match payload { Payload::RouteOrigin(origin) => format!( "{} prefix {:?}/{} max={} asn={}", - if route_origin_is_ipv4(origin) { "IPv4" } else { "IPv6" }, + if route_origin_is_ipv4(origin) { + "IPv4" + } else { + "IPv6" + }, origin.prefix().address, origin.prefix().prefix_length, origin.max_length(), diff --git a/src/rtr/cache/store.rs b/src/rtr/cache/store.rs index 47396a4..d30fc01 100644 --- a/src/rtr/cache/store.rs +++ b/src/rtr/cache/store.rs @@ -18,12 +18,9 @@ impl RtrCache { timing: Timing, file_loader: impl Fn() -> Result>, ) -> Result { - if let Some(cache) = try_restore_from_store( - store, - max_delta, - prune_delta_by_snapshot_size, - timing, - )? { + if let Some(cache) = + try_restore_from_store(store, max_delta, prune_delta_by_snapshot_size, timing)? + { tracing::info!( "RTR cache restored from store: availability={:?}, session_ids={:?}, serial={}, snapshot(route_origins={}, router_keys={}, aspas={})", cache.availability(), diff --git a/src/rtr/error_type.rs b/src/rtr/error_type.rs index 95eaaa9..f62a02f 100644 --- a/src/rtr/error_type.rs +++ b/src/rtr/error_type.rs @@ -19,7 +19,6 @@ pub enum ErrorCode { } impl ErrorCode { - #[inline] pub fn as_u16(self) -> u16 { self as u16 @@ -27,41 +26,29 @@ impl ErrorCode { pub fn description(self) -> &'static str { match self { - ErrorCode::CorruptData => - "Corrupt Data", + ErrorCode::CorruptData => "Corrupt Data", - ErrorCode::InternalError => - "Internal Error", + ErrorCode::InternalError => "Internal Error", - ErrorCode::NoDataAvailable => - "No Data Available", + ErrorCode::NoDataAvailable => "No Data Available", - ErrorCode::InvalidRequest => - "Invalid Request", + ErrorCode::InvalidRequest => "Invalid Request", - ErrorCode::UnsupportedProtocolVersion => - "Unsupported Protocol Version", + ErrorCode::UnsupportedProtocolVersion => "Unsupported Protocol Version", - ErrorCode::UnsupportedPduType => - "Unsupported PDU Type", + ErrorCode::UnsupportedPduType => "Unsupported PDU Type", - ErrorCode::WithdrawalOfUnknownRecord => - "Withdrawal of Unknown Record", + ErrorCode::WithdrawalOfUnknownRecord => "Withdrawal of Unknown Record", - ErrorCode::DuplicateAnnouncement => - "Duplicate Announcement Received", + ErrorCode::DuplicateAnnouncement => "Duplicate Announcement Received", - ErrorCode::UnexpectedProtocolVersion => - "Unexpected Protocol Version", + ErrorCode::UnexpectedProtocolVersion => "Unexpected Protocol Version", - ErrorCode::AspaProviderListError => - "ASPA Provider List Error", + ErrorCode::AspaProviderListError => "ASPA Provider List Error", - ErrorCode::TransportFailed => - "Transport Failed", + ErrorCode::TransportFailed => "Transport Failed", - ErrorCode::OrderingError => - "Ordering Error", + ErrorCode::OrderingError => "Ordering Error", } } } @@ -90,9 +77,6 @@ impl TryFrom for ErrorCode { impl fmt::Display for ErrorCode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} ({})", - self.description(), - *self as u16 - ) + write!(f, "{} ({})", self.description(), *self as u16) } } diff --git a/src/rtr/loader.rs b/src/rtr/loader.rs index ce37bdb..06349a6 100644 --- a/src/rtr/loader.rs +++ b/src/rtr/loader.rs @@ -3,7 +3,7 @@ use std::net::IpAddr; use std::path::Path; use std::str::FromStr; -use anyhow::{anyhow, Context, Result}; +use 
anyhow::{Context, Result, anyhow}; use crate::data_model::resources::as_resources::Asn; use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; @@ -131,10 +131,9 @@ pub fn parse_vrp_line(line: &str) -> Result { } let prefix_part = parts[0]; - let max_len = u8::from_str(parts[1]) - .with_context(|| format!("invalid max_len: {}", parts[1]))?; - let asn = u32::from_str(parts[2]) - .with_context(|| format!("invalid asn: {}", parts[2]))?; + let max_len = + u8::from_str(parts[1]).with_context(|| format!("invalid max_len: {}", parts[1]))?; + let asn = u32::from_str(parts[2]).with_context(|| format!("invalid asn: {}", parts[2]))?; let (addr_str, prefix_len_str) = prefix_part .split_once('/') @@ -164,14 +163,13 @@ pub fn parse_aspa_line(line: &str) -> Result { )); } - let customer_asn = u32::from_str(parts[0]) - .with_context(|| format!("invalid customer_asn: {}", parts[0]))?; + let customer_asn = + u32::from_str(parts[0]).with_context(|| format!("invalid customer_asn: {}", parts[0]))?; let provider_asns = parts[1] .split_whitespace() .map(|provider| { - u32::from_str(provider) - .with_context(|| format!("invalid provider_asn: {}", provider)) + u32::from_str(provider).with_context(|| format!("invalid provider_asn: {}", provider)) }) .collect::>>()?; @@ -186,23 +184,18 @@ pub fn parse_aspa_line(line: &str) -> Result { pub fn parse_router_key_line(line: &str) -> Result { let parts: Vec<_> = line.split(',').map(|s| s.trim()).collect(); if parts.len() != 3 { - return Err(anyhow!( - "expected format: ,," - )); + return Err(anyhow!("expected format: ,,")); } - let ski_vec = decode_hex(parts[0]) - .with_context(|| format!("invalid SKI hex: {}", parts[0]))?; + let ski_vec = decode_hex(parts[0]).with_context(|| format!("invalid SKI hex: {}", parts[0]))?; if ski_vec.len() != 20 { return Err(anyhow!("SKI must be exactly 20 bytes")); } let mut ski = [0u8; 20]; ski.copy_from_slice(&ski_vec); - let asn = u32::from_str(parts[1]) - .with_context(|| format!("invalid asn: {}", parts[1]))?; - let spki = decode_hex(parts[2]) - .with_context(|| format!("invalid SPKI hex: {}", parts[2]))?; + let asn = u32::from_str(parts[1]).with_context(|| format!("invalid asn: {}", parts[1]))?; + let spki = decode_hex(parts[2]).with_context(|| format!("invalid SPKI hex: {}", parts[2]))?; validate_router_key(asn, &spki)?; @@ -254,13 +247,9 @@ fn validate_aspa(customer_asn: u32, provider_asns: &[u32]) -> Result<()> { } fn validate_router_key(asn: u32, spki: &[u8]) -> Result<()> { - crate::rtr::payload::RouterKey::new( - Ski::default(), - Asn::from(asn), - spki.to_vec(), - ) - .validate() - .map_err(|err| anyhow!(err.to_string()))?; + crate::rtr::payload::RouterKey::new(Ski::default(), Asn::from(asn), spki.to_vec()) + .validate() + .map_err(|err| anyhow!(err.to_string()))?; Ok(()) } @@ -309,4 +298,3 @@ fn decode_hex(input: &str) -> Result> { }) .collect() } - diff --git a/src/rtr/mod.rs b/src/rtr/mod.rs index 1a86bc3..1cbfcbb 100644 --- a/src/rtr/mod.rs +++ b/src/rtr/mod.rs @@ -1,10 +1,9 @@ -pub mod pdu; pub mod cache; -pub mod payload; -pub mod store; -pub mod session; pub mod error_type; -pub mod state; -pub mod server; pub mod loader; -pub mod ccr; +pub mod payload; +pub mod pdu; +pub mod server; +pub mod session; +pub mod state; +pub mod store; diff --git a/src/rtr/payload.rs b/src/rtr/payload.rs index 826599e..a5aebbb 100644 --- a/src/rtr/payload.rs +++ b/src/rtr/payload.rs @@ -1,13 +1,12 @@ +use crate::data_model::resources::as_resources::Asn; +use crate::data_model::resources::ip_resources::IPAddressPrefix; 
+use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::io; use std::time::Duration; -use serde::{Deserialize, Serialize}; -use crate::data_model::resources::as_resources::Asn; -use crate::data_model::resources::ip_resources::IPAddressPrefix; use x509_parser::prelude::FromDer; use x509_parser::x509::SubjectPublicKeyInfo; - #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] enum PayloadPduType { Ipv4Prefix = 4, @@ -16,7 +15,9 @@ enum PayloadPduType { Aspa = 11, } -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive( + Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize, +)] pub struct Ski([u8; 20]); impl AsRef<[u8]> for Ski { @@ -60,7 +61,6 @@ impl RouteOrigin { } } - #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] pub struct RouterKey { subject_key_identifier: Ski, @@ -104,8 +104,8 @@ impl RouterKey { )); } - let (rem, _) = SubjectPublicKeyInfo::from_der(&self.subject_public_key_info) - .map_err(|err| { + let (rem, _) = + SubjectPublicKeyInfo::from_der(&self.subject_public_key_info).map_err(|err| { io::Error::new( io::ErrorKind::InvalidData, format!("RouterKey SPKI is not valid DER: {err}"), @@ -115,10 +115,7 @@ impl RouterKey { if !rem.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!( - "RouterKey SPKI DER has trailing bytes: {}", - rem.len() - ), + format!("RouterKey SPKI DER has trailing bytes: {}", rem.len()), )); } @@ -177,7 +174,6 @@ impl Aspa { } } - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub enum Payload { @@ -191,7 +187,6 @@ pub enum Payload { Aspa(Aspa), } - // Timing #[derive(Clone, Copy, Debug)] pub struct Timing { @@ -202,7 +197,7 @@ pub struct Timing { pub retry: u32, /// The number of secionds before data expires if not refreshed. 
- pub expire: u32 + pub expire: u32, } impl Timing { @@ -214,7 +209,11 @@ impl Timing { pub const MAX_EXPIRE: u32 = 172_800; pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self { - Self { refresh, retry, expire } + Self { + refresh, + retry, + expire, + } } pub fn validate(self) -> Result<(), io::Error> { @@ -223,7 +222,9 @@ impl Timing { io::ErrorKind::InvalidData, format!( "refresh interval {} out of range {}..={}", - self.refresh, Self::MIN_REFRESH, Self::MAX_REFRESH + self.refresh, + Self::MIN_REFRESH, + Self::MAX_REFRESH ), )); } @@ -233,7 +234,9 @@ impl Timing { io::ErrorKind::InvalidData, format!( "retry interval {} out of range {}..={}", - self.retry, Self::MIN_RETRY, Self::MAX_RETRY + self.retry, + Self::MIN_RETRY, + Self::MAX_RETRY ), )); } @@ -243,7 +246,9 @@ impl Timing { io::ErrorKind::InvalidData, format!( "expire interval {} out of range {}..={}", - self.expire, Self::MIN_EXPIRE, Self::MAX_EXPIRE + self.expire, + Self::MIN_EXPIRE, + Self::MAX_EXPIRE ), )); } @@ -282,7 +287,6 @@ impl Timing { pub fn expire(self) -> Duration { Duration::from_secs(u64::from(self.expire)) } - } impl Default for Timing { diff --git a/src/rtr/pdu.rs b/src/rtr/pdu.rs index 998986c..1a77e3d 100644 --- a/src/rtr/pdu.rs +++ b/src/rtr/pdu.rs @@ -1,16 +1,16 @@ -use std::{cmp, mem}; -use std::net::{Ipv4Addr, Ipv6Addr}; -use std::sync::Arc; use crate::data_model::resources::as_resources::Asn; use crate::rtr::error_type::ErrorCode; use crate::rtr::payload::{Ski, Timing}; -use std::io; -use tokio::io::{AsyncWrite}; use anyhow::Result; +use std::io; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::Arc; +use std::{cmp, mem}; +use tokio::io::AsyncWrite; -use std::slice; use anyhow::bail; use serde::Serialize; +use std::slice; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; pub const HEADER_LEN: usize = 8; @@ -27,10 +27,7 @@ macro_rules! common { #[allow(dead_code)] impl $type { /// Writes a value to a writer. - pub async fn write( - &self, - a: &mut A - ) -> Result<(), io::Error> { + pub async fn write(&self, a: &mut A) -> Result<(), io::Error> { a.write_all(self.as_ref()).await } } @@ -38,10 +35,7 @@ macro_rules! common { impl AsRef<[u8]> for $type { fn as_ref(&self) -> &[u8] { unsafe { - slice::from_raw_parts( - self as *const Self as *const u8, - mem::size_of::() - ) + slice::from_raw_parts(self as *const Self as *const u8, mem::size_of::()) } } } @@ -49,14 +43,11 @@ macro_rules! common { impl AsMut<[u8]> for $type { fn as_mut(&mut self) -> &mut [u8] { unsafe { - slice::from_raw_parts_mut( - self as *mut Self as *mut u8, - mem::size_of::() - ) + slice::from_raw_parts_mut(self as *mut Self as *mut u8, mem::size_of::()) } } } - } + }; } macro_rules! concrete { @@ -94,28 +85,20 @@ macro_rules! concrete { /// /// If a value with a different PDU type is received, returns an /// error. 
- pub async fn read( - sock: &mut Sock - ) -> Result { + pub async fn read(sock: &mut Sock) -> Result { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, - concat!( - "PDU type mismatch when expecting ", - stringify!($type) - ) - )) + concat!("PDU type mismatch when expecting ", stringify!($type)), + )); } if res.header.length() as usize != res.as_ref().len() { return Err(io::Error::new( io::ErrorKind::InvalidData, - concat!( - "invalid length for ", - stringify!($type) - ) - )) + concat!("invalid length for ", stringify!($type)), + )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; Ok(res) @@ -126,32 +109,26 @@ macro_rules! concrete { /// If a different PDU type is received, returns the header as /// the error case of the ok case. pub async fn try_read( - sock: &mut Sock + sock: &mut Sock, ) -> Result, io::Error> { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() == ErrorReport::PDU { // Since we should drop the session after an error, we // can safely ignore all the rest of the error for now. - return Ok(Err(res.header)) + return Ok(Err(res.header)); } if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, - concat!( - "PDU type mismatch when expecting ", - stringify!($type) - ) - )) + concat!("PDU type mismatch when expecting ", stringify!($type)), + )); } if res.header.length() as usize != res.as_ref().len() { return Err(io::Error::new( io::ErrorKind::InvalidData, - concat!( - "invalid length for ", - stringify!($type) - ) - )) + concat!("invalid length for ", stringify!($type)), + )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; Ok(Ok(res)) @@ -163,17 +140,14 @@ macro_rules! concrete { /// `header`, the function reads the rest of the PUD from the /// reader and returns the complete value. pub async fn read_payload( - header: Header, sock: &mut Sock + header: Header, + sock: &mut Sock, ) -> Result { if header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, - concat!( - "invalid length for ", - stringify!($type), - " PDU" - ) - )) + concat!("invalid length for ", stringify!($type), " PDU"), + )); } let mut res = Self::default(); sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; @@ -181,10 +155,9 @@ macro_rules! concrete { Ok(res) } } - } + }; } - // 所有PDU公共头部信息 #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] @@ -196,7 +169,6 @@ pub struct Header { } impl Header { - const LEN: usize = mem::size_of::(); pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self { Header { @@ -208,7 +180,7 @@ impl Header { } pub async fn read_raw( - sock: &mut S + sock: &mut S, ) -> Result<[u8; HEADER_LEN], io::Error> { let mut buf = [0u8; HEADER_LEN]; sock.read_exact(&mut buf).await?; @@ -229,10 +201,7 @@ impl Header { } if length > MAX_PDU_LEN { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "PDU too large", - )); + return Err(io::Error::new(io::ErrorKind::InvalidData, "PDU too large")); } Ok(Self { @@ -247,13 +216,21 @@ impl Header { Self::from_raw(Self::read_raw(sock).await?) 
} - pub fn version(self) -> u8{self.version} + pub fn version(self) -> u8 { + self.version + } - pub fn pdu(self) -> u8{self.pdu} + pub fn pdu(self) -> u8 { + self.pdu + } - pub fn session_id(self) -> u16{u16::from_be(self.session_id)} + pub fn session_id(self) -> u16 { + u16::from_be(self.session_id) + } - pub fn length(self) -> u32{u32::from_be(self.length)} + pub fn length(self) -> u32 { + u32::from_be(self.length) + } pub fn pdu_len(self) -> Result { usize::try_from(self.length()).map_err(|_| { @@ -268,7 +245,6 @@ impl Header { debug_assert_eq!(self.pdu(), ErrorReport::PDU); self.session_id() } - } common!(Header); @@ -304,12 +280,10 @@ impl HeaderWithFlags { let pdu = buf[1]; let flags = buf[2]; let zero = buf[3]; - let length = u32::from_be_bytes([ - buf[4], buf[5], buf[6], buf[7], - ]); + let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]); // 3. 基础合法性校验 - if length < HEADER_LEN as u32{ + if length < HEADER_LEN as u32 { bail!("Invalid PDU length"); } @@ -327,18 +301,27 @@ impl HeaderWithFlags { }) } - pub fn version(self) -> u8{self.version} + pub fn version(self) -> u8 { + self.version + } - pub fn pdu(self) -> u8{self.pdu} + pub fn pdu(self) -> u8 { + self.pdu + } - pub fn flags(self) -> Flags{Flags(self.flags)} + pub fn flags(self) -> Flags { + Flags(self.flags) + } - pub fn zero(self) -> u8 { self.zero } + pub fn zero(self) -> u8 { + self.zero + } - pub fn length(self) -> u32{u32::from_be(self.length)} + pub fn length(self) -> u32 { + u32::from_be(self.length) + } } - // Serial Notify #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] @@ -351,20 +334,18 @@ impl SerialNotify { pub const PDU: u8 = 0; pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { - SerialNotify{ + SerialNotify { header: Header::new(version, Self::PDU, session_id, Self::size()), - serial_number: serial_number.to_be() + serial_number: serial_number.to_be(), } } pub fn serial_number(self) -> u32 { u32::from_be(self.serial_number) } - } concrete!(SerialNotify); - // Serial Query #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] @@ -377,9 +358,9 @@ impl SerialQuery { pub const PDU: u8 = 1; pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { - SerialQuery{ + SerialQuery { header: Header::new(version, Self::PDU, session_id, Self::size()), - serial_number: serial_number.to_be() + serial_number: serial_number.to_be(), } } @@ -390,12 +371,11 @@ impl SerialQuery { concrete!(SerialQuery); - // Reset Query #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ResetQuery { - header: Header + header: Header, } impl ResetQuery { @@ -410,7 +390,6 @@ impl ResetQuery { concrete!(ResetQuery); - // Cache Response #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] @@ -430,7 +409,6 @@ impl CacheResponse { concrete!(CacheResponse); - // Flags #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct Flags(u8); @@ -464,7 +442,7 @@ pub struct IPv4Prefix { max_len: u8, zero: u8, prefix: u32, - asn: u32 + asn: u32, } impl IPv4Prefix { @@ -475,7 +453,7 @@ impl IPv4Prefix { prefix_len: u8, max_len: u8, prefix: Ipv4Addr, - asn: Asn + asn: Asn, ) -> Self { IPv4Prefix { header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN), @@ -488,12 +466,22 @@ impl IPv4Prefix { } } - pub fn flag(self) -> Flags{self.flags} + pub fn flag(self) -> Flags { + self.flags + } - pub fn prefix_len(self) -> 
u8{self.prefix_len} - pub fn max_len(self) -> u8{self.max_len} - pub fn prefix(self) -> Ipv4Addr{u32::from_be(self.prefix).into()} - pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} + pub fn prefix_len(self) -> u8 { + self.prefix_len + } + pub fn max_len(self) -> u8 { + self.max_len + } + pub fn prefix(self) -> Ipv4Addr { + u32::from_be(self.prefix).into() + } + pub fn asn(self) -> Asn { + u32::from_be(self.asn).into() + } } concrete!(IPv4Prefix); @@ -509,7 +497,7 @@ pub struct IPv6Prefix { max_len: u8, zero: u8, prefix: u128, - asn: u32 + asn: u32, } impl IPv6Prefix { @@ -520,7 +508,7 @@ impl IPv6Prefix { prefix_len: u8, max_len: u8, prefix: Ipv6Addr, - asn: Asn + asn: Asn, ) -> Self { IPv6Prefix { header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN), @@ -533,17 +521,26 @@ impl IPv6Prefix { } } - pub fn flag(self) -> Flags{self.flags} + pub fn flag(self) -> Flags { + self.flags + } - pub fn prefix_len(self) -> u8{self.prefix_len} - pub fn max_len(self) -> u8{self.max_len} - pub fn prefix(self) -> Ipv6Addr{u128::from_be(self.prefix).into()} - pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} + pub fn prefix_len(self) -> u8 { + self.prefix_len + } + pub fn max_len(self) -> u8 { + self.max_len + } + pub fn prefix(self) -> Ipv6Addr { + u128::from_be(self.prefix).into() + } + pub fn asn(self) -> Asn { + u32::from_be(self.asn).into() + } } concrete!(IPv6Prefix); - // End of Data #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)] pub enum EndOfData { @@ -559,14 +556,20 @@ impl EndOfData { timing: Timing, ) -> Result { if version == 0 { - Ok(EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number))) - } - else { - Ok(EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)?)) + Ok(EndOfData::V0(EndOfDataV0::new( + version, + session_id, + serial_number, + ))) + } else { + Ok(EndOfData::V1(EndOfDataV1::new( + version, + session_id, + serial_number, + timing, + )?)) } } - - } #[repr(C, packed)] @@ -587,11 +590,12 @@ impl EndOfDataV0 { } } - pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} + pub fn serial_number(self) -> u32 { + u32::from_be(self.serial_number) + } } concrete!(EndOfDataV0); - #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct EndOfDataV1 { @@ -601,7 +605,6 @@ pub struct EndOfDataV1 { refresh_interval: u32, retry_interval: u32, expire_interval: u32, - } impl EndOfDataV1 { @@ -640,9 +643,11 @@ impl EndOfDataV1 { }) } - pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} + pub fn serial_number(self) -> u32 { + u32::from_be(self.serial_number) + } - pub fn timing(self) -> Timing{ + pub fn timing(self) -> Timing { Timing { refresh: u32::from_be(self.refresh_interval), retry: u32::from_be(self.retry_interval), @@ -654,22 +659,20 @@ impl EndOfDataV1 { self.timing().validate() } - pub async fn read( - sock: &mut Sock - ) -> Result { + pub async fn read(sock: &mut Sock) -> Result { let mut res = Self::default(); sock.read_exact(res.header.as_mut()).await?; if res.header.pdu() != Self::PDU { return Err(io::Error::new( io::ErrorKind::InvalidData, "PDU type mismatch when expecting EndOfDataV1", - )) + )); } if res.header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for EndOfDataV1", - )) + )); } sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; res.validate()?; @@ -677,13 +680,14 @@ impl EndOfDataV1 { } pub async fn read_payload( - header: Header, sock: &mut Sock + header: 
Header, + sock: &mut Sock, ) -> Result { if header.length() as usize != mem::size_of::() { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid length for EndOfDataV1 PDU", - )) + )); } let mut res = Self::default(); sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; @@ -704,23 +708,21 @@ pub struct CacheReset { impl CacheReset { pub const PDU: u8 = 8; - pub fn new(version: u8) -> Self{ + pub fn new(version: u8) -> Self { CacheReset { - header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32) + header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32), } } } concrete!(CacheReset); - // Error Report #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] pub struct ErrorReport { octets: Vec, } - impl ErrorReport { /// The PDU type of an error PDU. pub const PDU: u8 = 10; @@ -741,27 +743,19 @@ impl ErrorReport { let text_len = cmp::min(text.len(), text_room); let size = Self::FIXED_PART_LEN + pdu_len + text_len; - let header = Header::new( - version, 10, error_code, u32::try_from(size).unwrap() - ); + let header = Header::new(version, 10, error_code, u32::try_from(size).unwrap()); let mut octets = Vec::with_capacity(size); octets.extend_from_slice(header.as_ref()); - octets.extend_from_slice( - u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref() - ); + octets.extend_from_slice(u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref()); octets.extend_from_slice(&pdu[..pdu_len]); - octets.extend_from_slice( - u32::try_from(text_len).unwrap().to_be_bytes().as_ref() - ); + octets.extend_from_slice(u32::try_from(text_len).unwrap().to_be_bytes().as_ref()); octets.extend_from_slice(&text[..text_len]); ErrorReport { octets } } - pub async fn read( - sock: &mut Sock - ) -> Result { + pub async fn read(sock: &mut Sock) -> Result { let header = Header::read(sock).await?; if header.pdu() != Self::PDU { return Err(io::Error::new( @@ -787,7 +781,8 @@ impl ErrorReport { let mut octets = Vec::with_capacity(total_len); octets.extend_from_slice(header.as_ref()); octets.resize(total_len, 0); - sock.read_exact(&mut octets[mem::size_of::
()..]).await?; + sock.read_exact(&mut octets[mem::size_of::<Header>
()..]) + .await?; let res = ErrorReport { octets }; res.validate()?; Ok(res) } @@ -813,7 +808,8 @@ impl ErrorReport { /// Skips over the payload of the error PDU. pub async fn skip_payload( - header: Header, sock: &mut Sock + header: Header, + sock: &mut Sock, ) -> Result<(), io::Error> { let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::<Header>
()) else { return Err(io::Error::new( @@ -840,9 +836,7 @@ impl ErrorReport { } /// Writes the PUD to a writer. - pub async fn write( - &self, a: &mut A - ) -> Result<(), io::Error> { + pub async fn write(&self, a: &mut A) -> Result<(), io::Error> { a.write_all(self.as_ref()).await } @@ -860,7 +854,7 @@ impl ErrorReport { u32::from_be_bytes( self.octets[Header::LEN..Header::LEN + 4] .try_into() - .unwrap() + .unwrap(), ) as usize } @@ -876,11 +870,7 @@ impl ErrorReport { fn text_len(&self) -> usize { let offset = self.text_len_offset(); - u32::from_be_bytes( - self.octets[offset..offset + 4] - .try_into() - .unwrap() - ) as usize + u32::from_be_bytes(self.octets[offset..offset + 4].try_into().unwrap()) as usize } fn text_range(&self) -> std::ops::Range { @@ -916,7 +906,10 @@ impl ErrorReport { let pdu_len = self.erroneous_pdu_len(); let text_len_offset = Header::LEN + 4 + pdu_len; let Some(text_len_end) = text_len_offset.checked_add(4) else { - return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport length overflow")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport length overflow", + )); }; if text_len_end > self.octets.len() { return Err(io::Error::new( @@ -928,10 +921,13 @@ impl ErrorReport { let text_len = u32::from_be_bytes( self.octets[text_len_offset..text_len_end] .try_into() - .unwrap() + .unwrap(), ) as usize; let Some(text_end) = text_len_end.checked_add(text_len) else { - return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport text overflow")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "ErrorReport text overflow", + )); }; if text_end != self.octets.len() { return Err(io::Error::new( @@ -951,7 +947,6 @@ impl ErrorReport { } } - // TODO: 补全 /// Router Key #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] @@ -966,13 +961,10 @@ pub struct RouterKey { } impl RouterKey { - pub const PDU: u8 = 9; const BASE_LEN: usize = HEADER_LEN + 20 + 4; - pub async fn read( - sock: &mut Sock - ) -> Result { + pub async fn read(sock: &mut Sock) -> Result { let header = HeaderWithFlags::read(sock) .await .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; @@ -1022,25 +1014,19 @@ impl RouterKey { Ok(res) } - pub async fn write( - &self, - w: &mut A, - ) -> Result<(), io::Error> { + pub async fn write(&self, w: &mut A) -> Result<(), io::Error> { let length = Self::BASE_LEN + self.subject_public_key_info.len(); - let header = HeaderWithFlags::new( - self.header.version(), - Self::PDU, - self.flags, - length as u32, - ); + let header = + HeaderWithFlags::new(self.header.version(), Self::PDU, self.flags, length as u32); w.write_all(&[ header.version(), header.pdu(), header.flags().into_u8(), ZERO_8, - ]).await?; + ]) + .await?; w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(self.ski.as_ref()).await?; @@ -1120,24 +1106,20 @@ impl RouterKey { } } - - // ASPA #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] -pub struct Aspa{ +pub struct Aspa { header: HeaderWithFlags, customer_asn: u32, - provider_asns: Vec + provider_asns: Vec, } impl Aspa { pub const PDU: u8 = 11; const BASE_LEN: usize = HEADER_LEN + 4; - pub async fn read( - sock: &mut Sock - ) -> Result { + pub async fn read(sock: &mut Sock) -> Result { let header = HeaderWithFlags::read(sock) .await .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; @@ -1192,11 +1174,7 @@ impl Aspa { Ok(res) } - pub async fn write( - &self, - w: &mut A, - ) -> Result<(), io::Error> { - + pub async fn 
write(&self, w: &mut A) -> Result<(), io::Error> { let length = Self::BASE_LEN + (self.provider_asns.len() * 4); let header = HeaderWithFlags::new( @@ -1211,7 +1189,8 @@ impl Aspa { header.pdu(), header.flags().into_u8(), ZERO_8, - ]).await?; + ]) + .await?; w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(&self.customer_asn.to_be_bytes()).await?; @@ -1222,12 +1201,7 @@ impl Aspa { Ok(()) } - pub fn new( - version: u8, - flags: Flags, - customer_asn: u32, - provider_asns: Vec, - ) -> Self { + pub fn new(version: u8, flags: Flags, customer_asn: u32, provider_asns: Vec) -> Self { let length = Self::BASE_LEN + (provider_asns.len() * 4); Self { @@ -1306,7 +1280,6 @@ impl Aspa { } } - //--- AsRef and AsMut impl AsRef<[u8]> for ErrorReport { fn as_ref(&self) -> &[u8] { @@ -1319,4 +1292,3 @@ impl AsMut<[u8]> for ErrorReport { self.octets.as_mut() } } - diff --git a/src/rtr/server/connection.rs b/src/rtr/server/connection.rs index 088899d..7dbed16 100644 --- a/src/rtr/server/connection.rs +++ b/src/rtr/server/connection.rs @@ -6,7 +6,7 @@ use std::sync::{ use anyhow::{Context, Result, anyhow}; use tokio::net::TcpStream; -use tokio::sync::{broadcast, watch, OwnedSemaphorePermit}; +use tokio::sync::{OwnedSemaphorePermit, broadcast, watch}; use tracing::{info, warn}; use x509_parser::extensions::GeneralName; use x509_parser::prelude::{FromDer, X509Certificate}; @@ -22,10 +22,7 @@ pub struct ConnectionGuard { } impl ConnectionGuard { - pub fn new( - active_connections: Arc, - permit: OwnedSemaphorePermit, - ) -> Self { + pub fn new(active_connections: Arc, permit: OwnedSemaphorePermit) -> Self { active_connections.fetch_add(1, Ordering::Relaxed); Self { active_connections, @@ -72,8 +69,12 @@ pub async fn handle_tls_connection( .await .with_context(|| format!("TLS handshake failed for {}", peer_addr))?; info!("RTR TLS handshake completed for {}", peer_addr); - verify_peer_certificate_ip(&tls_stream, peer_addr.ip()) - .with_context(|| format!("TLS client certificate SAN IP validation failed for {}", peer_addr))?; + verify_peer_certificate_ip(&tls_stream, peer_addr.ip()).with_context(|| { + format!( + "TLS client certificate SAN IP validation failed for {}", + peer_addr + ) + })?; info!("RTR TLS client certificate validated for {}", peer_addr); let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx); @@ -122,16 +123,16 @@ fn verify_peer_certificate_ip( GeneralName::IPAddress(bytes) => { let bytes = *bytes; match (peer_ip, bytes.len()) { - (IpAddr::V4(ip), 4) => <[u8; 4]>::try_from(bytes) - .map(IpAddr::from) - .map(|cert_ip| cert_ip == IpAddr::V4(ip)) - .unwrap_or(false), - (IpAddr::V6(ip), 16) => <[u8; 16]>::try_from(bytes) - .map(IpAddr::from) - .map(|cert_ip| cert_ip == IpAddr::V6(ip)) - .unwrap_or(false), - _ => false, - } + (IpAddr::V4(ip), 4) => <[u8; 4]>::try_from(bytes) + .map(IpAddr::from) + .map(|cert_ip| cert_ip == IpAddr::V4(ip)) + .unwrap_or(false), + (IpAddr::V6(ip), 16) => <[u8; 16]>::try_from(bytes) + .map(IpAddr::from) + .map(|cert_ip| cert_ip == IpAddr::V6(ip)) + .unwrap_or(false), + _ => false, + } } _ => false, }); diff --git a/src/rtr/server/listener.rs b/src/rtr/server/listener.rs index 700e746..ed6fafa 100644 --- a/src/rtr/server/listener.rs +++ b/src/rtr/server/listener.rs @@ -1,28 +1,22 @@ use std::net::SocketAddr; use std::path::Path; -use std::sync::{ - Arc, - atomic::AtomicUsize, -}; +use std::sync::{Arc, atomic::AtomicUsize}; use std::time::Duration; use anyhow::{Context, Result}; use socket2::{SockRef, TcpKeepalive}; use 
tokio::net::TcpListener; -use tokio::sync::{broadcast, watch, Semaphore}; +use tokio::sync::{Semaphore, broadcast, watch}; use tracing::{info, warn}; use rustls::ServerConfig; use tokio_rustls::TlsAcceptor; use crate::rtr::cache::SharedRtrCache; -use crate::rtr::server::connection::{ - ConnectionGuard, - handle_tcp_connection, - handle_tls_connection, - is_expected_disconnect, -}; use crate::rtr::server::config::RtrServiceConfig; +use crate::rtr::server::connection::{ + ConnectionGuard, handle_tcp_connection, handle_tls_connection, is_expected_disconnect, +}; use crate::rtr::server::tls::load_rustls_server_config_with_options; pub struct RtrServer { @@ -65,7 +59,8 @@ impl RtrServer { } pub fn active_connections(&self) -> usize { - self.active_connections.load(std::sync::atomic::Ordering::Relaxed) + self.active_connections + .load(std::sync::atomic::Ordering::Relaxed) } pub async fn run_tcp(self) -> Result<()> { @@ -293,10 +288,7 @@ impl RtrServer { } } -fn apply_keepalive( - stream: &tokio::net::TcpStream, - keepalive: Option, -) -> Result<()> { +fn apply_keepalive(stream: &tokio::net::TcpStream, keepalive: Option) -> Result<()> { let Some(keepalive) = keepalive else { return Ok(()); }; diff --git a/src/rtr/server/mod.rs b/src/rtr/server/mod.rs index 3d833af..53ced47 100644 --- a/src/rtr/server/mod.rs +++ b/src/rtr/server/mod.rs @@ -9,4 +9,4 @@ pub use config::RtrServiceConfig; pub use listener::RtrServer; pub use notifier::RtrNotifier; pub use service::{RtrService, RunningRtrService}; -pub use tls::load_rustls_server_config; \ No newline at end of file +pub use tls::load_rustls_server_config; diff --git a/src/rtr/server/notifier.rs b/src/rtr/server/notifier.rs index 208b18e..075626a 100644 --- a/src/rtr/server/notifier.rs +++ b/src/rtr/server/notifier.rs @@ -13,4 +13,4 @@ impl RtrNotifier { pub fn notify_cache_updated(&self) { let _ = self.tx.send(()); } -} \ No newline at end of file +} diff --git a/src/rtr/server/service.rs b/src/rtr/server/service.rs index cdb7f53..b86a3a1 100644 --- a/src/rtr/server/service.rs +++ b/src/rtr/server/service.rs @@ -5,7 +5,7 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, }; -use tokio::sync::{broadcast, watch, Semaphore}; +use tokio::sync::{Semaphore, broadcast, watch}; use tokio::task::JoinHandle; use tracing::{error, warn}; @@ -114,7 +114,10 @@ impl RtrService { let server = self.tls_server(bind_addr); tokio::spawn(async move { - if let Err(err) = server.run_tls_from_pem(cert_path, key_path, client_ca_path).await { + if let Err(err) = server + .run_tls_from_pem(cert_path, key_path, client_ca_path) + .await + { error!("RTR TLS server {} exited with error: {:?}", bind_addr, err); } }) @@ -129,7 +132,8 @@ impl RtrService { client_ca_path: impl AsRef, ) -> RunningRtrService { let tcp_handle = self.spawn_tcp(tcp_bind_addr); - let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path); + let tls_handle = + self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path); RunningRtrService { shutdown_tx: self.shutdown_tx.clone(), diff --git a/src/rtr/server/tls.rs b/src/rtr/server/tls.rs index ca7f02c..ead4164 100644 --- a/src/rtr/server/tls.rs +++ b/src/rtr/server/tls.rs @@ -3,7 +3,7 @@ use std::io::BufReader; use std::path::{Path, PathBuf}; use std::sync::Arc; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result, anyhow}; use rustls::server::WebPkiClientVerifier; use rustls::{RootCertStore, ServerConfig}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; @@ -36,8 +36,12 @@ pub fn 
load_rustls_server_config_with_options( let key = load_private_key(&key_path) .with_context(|| format!("failed to load private key from {}", key_path.display()))?; - let client_ca_certs = load_certs(&client_ca_path) - .with_context(|| format!("failed to load client CA certs from {}", client_ca_path.display()))?; + let client_ca_certs = load_certs(&client_ca_path).with_context(|| { + format!( + "failed to load client CA certs from {}", + client_ca_path.display() + ) + })?; let mut client_roots = RootCertStore::empty(); let (added, _) = client_roots.add_parsable_certificates(client_ca_certs); if added == 0 { @@ -100,8 +104,7 @@ fn load_certs(path: &Path) -> Result>> { let file = File::open(path)?; let mut reader = BufReader::new(file); - let certs = rustls_pemfile::certs(&mut reader) - .collect::, _>>()?; + let certs = rustls_pemfile::certs(&mut reader).collect::, _>>()?; if certs.is_empty() { return Err(anyhow!("no certificates found in {}", path.display())); diff --git a/src/rtr/session.rs b/src/rtr/session.rs index f5e09c3..922d072 100644 --- a/src/rtr/session.rs +++ b/src/rtr/session.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::{anyhow, bail, Result}; +use anyhow::{Result, anyhow, bail}; use tokio::io; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{broadcast, watch}; @@ -14,13 +14,11 @@ use crate::rtr::cache::{ validate_payloads_for_rtr, }; use crate::rtr::error_type::ErrorCode; -use crate::rtr::pdu::{ - Aspa as AspaPdu, - CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix, - ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery, - HEADER_LEN, -}; use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; +use crate::rtr::pdu::{ + Aspa as AspaPdu, CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, HEADER_LEN, Header, + IPv4Prefix, IPv6Prefix, ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery, +}; const SUPPORTED_MAX_VERSION: u8 = 2; const SUPPORTED_MIN_VERSION: u8 = 0; @@ -83,10 +81,7 @@ where } async fn run_inner(&mut self) -> Result<()> { - info!( - "RTR session started: {}", - self.session_summary() - ); + info!("RTR session started: {}", self.session_summary()); loop { let transport_timeout = self.transport_timeout(); tokio::select! 
{ @@ -265,7 +260,10 @@ where self.session_summary() ); } else { - debug!("RTR session transport shutdown completed: {}", self.session_summary()); + debug!( + "RTR session transport shutdown completed: {}", + self.session_summary() + ); } } @@ -351,7 +349,7 @@ where offending_pdu, msg.as_bytes(), ) - .await + .await } async fn send_unexpected_version( @@ -362,8 +360,7 @@ where ) -> io::Result<()> { let msg = format!( "unexpected protocol version {}, established version is {}", - received_version, - established_version + received_version, established_version ); self.send_error( @@ -372,7 +369,7 @@ where offending_pdu, msg.as_bytes(), ) - .await + .await } async fn send_corrupt_session_id( @@ -393,7 +390,7 @@ where offending_pdu, msg.as_bytes(), ) - .await + .await } async fn send_corrupt_data( @@ -402,13 +399,8 @@ where offending_pdu: &[u8], detail: &[u8], ) -> io::Result<()> { - self.send_error( - version, - ErrorCode::CorruptData, - offending_pdu, - detail, - ) - .await + self.send_error(version, ErrorCode::CorruptData, offending_pdu, detail) + .await } async fn send_no_data_available( @@ -458,7 +450,8 @@ where self.state = SessionState::Closed; bail!( "router version {} higher than cache max {}", - version, SUPPORTED_MAX_VERSION + version, + SUPPORTED_MAX_VERSION ); } self.negotiate_version(version).await?; @@ -495,14 +488,16 @@ where self.state = SessionState::Closed; bail!( "router version {} higher than cache max {}", - version, SUPPORTED_MAX_VERSION + version, + SUPPORTED_MAX_VERSION ); } self.negotiate_version(version).await?; let session_id = query.session_id(); let serial = query.serial_number(); - self.handle_serial(version, session_id, serial, query.as_ref()).await?; + self.handle_serial(version, session_id, serial, query.as_ref()) + .await?; self.state = SessionState::Established; info!( "RTR session established after Serial Query: negotiated_version={}, client_session_id={}, client_serial={}, {}", @@ -613,7 +608,10 @@ where .cache .read() .map_err(|_| anyhow!("cache read lock poisoned"))?; - (cache.is_data_available(), cache.session_id_for_version(version)) + ( + cache.is_data_available(), + cache.session_id_for_version(version), + ) }; if !data_available { @@ -723,7 +721,10 @@ where let now = Instant::now(); if let Some(last) = self.last_notify_at { if now.duration_since(last) < NOTIFY_MIN_INTERVAL { - debug!("RTR session notify skipped due to rate limit: {}", self.session_summary()); + debug!( + "RTR session notify skipped due to rate limit: {}", + self.session_summary() + ); return Ok(()); } } @@ -824,8 +825,7 @@ where let version = self.version()?; debug!( "RTR session writing Cache Response: version={}, session_id={}", - version, - session_id + version, session_id ); CacheResponse::new(version, session_id) .write(&mut self.stream) @@ -835,10 +835,7 @@ where async fn write_cache_reset(&mut self) -> Result<()> { let version = self.version()?; - info!( - "RTR session writing Cache Reset: version={}", - version - ); + info!("RTR session writing Cache Reset: version={}", version); CacheReset::new(version).write(&mut self.stream).await?; Ok(()) } @@ -880,8 +877,7 @@ where // References: // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 - validate_payloads_for_rtr(payloads, announce) - .map_err(|err| anyhow!(err.to_string()))?; + validate_payloads_for_rtr(payloads, announce).map_err(|err| anyhow!(err.to_string()))?; let (route_origins, router_keys, aspas) = 
count_payloads(payloads); debug!( "RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}", @@ -906,8 +902,7 @@ where // References: // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 - validate_payload_updates_for_rtr(&updates) - .map_err(|err| anyhow!(err.to_string()))?; + validate_payload_updates_for_rtr(&updates).map_err(|err| anyhow!(err.to_string()))?; let (announced, withdrawn, route_origins, router_keys, aspas) = count_payload_updates(&updates); debug!( @@ -1010,8 +1005,7 @@ where }); let providers = if announce { - aspa - .provider_asns() + aspa.provider_asns() .iter() .map(|asn| asn.into_u32()) .collect::>() @@ -1019,18 +1013,12 @@ where Vec::new() }; - let pdu = AspaPdu::new( - version, - flags, - aspa.customer_asn().into_u32(), - providers, - ); + let pdu = AspaPdu::new(version, flags, aspa.customer_asn().into_u32(), providers); pdu.write(&mut self.stream).await?; Ok(()) } - async fn send_error( &mut self, version: u8, @@ -1052,11 +1040,7 @@ where .await } - async fn handle_pdu_read_error( - &mut self, - header: Header, - err: io::Error, - ) -> Result<()> { + async fn handle_pdu_read_error(&mut self, header: Header, err: io::Error) -> Result<()> { warn!( "RTR session failed to read established-session PDU payload: pdu={}, version={}, err={}", header.pdu(), @@ -1076,11 +1060,7 @@ where Ok(()) } - async fn handle_first_pdu_read_error( - &mut self, - header: Header, - err: io::Error, - ) -> Result<()> { + async fn handle_first_pdu_read_error(&mut self, header: Header, err: io::Error) -> Result<()> { warn!( "RTR session failed to read first PDU payload: pdu={}, version={}, err={}", header.pdu(), @@ -1089,13 +1069,12 @@ where ); if err.kind() == io::ErrorKind::InvalidData { let offending = self.read_full_pdu_bytes(header).await?; - let err_version = if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION) - .contains(&header.version()) - { - header.version() - } else { - SUPPORTED_MAX_VERSION - }; + let err_version = + if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION).contains(&header.version()) { + header.version() + } else { + SUPPORTED_MAX_VERSION + }; let detail = err.to_string(); let _ = self @@ -1114,13 +1093,14 @@ where ) -> Result<()> { warn!( "RTR session handling invalid header bytes: raw_header={:02X?}, err={}", - raw_header, - err + raw_header, err ); if err.kind() == io::ErrorKind::InvalidData { let version = match self.version { Some(version) => version, - None if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION).contains(&raw_header[0]) => { + None if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION) + .contains(&raw_header[0]) => + { raw_header[0] } None => SUPPORTED_MAX_VERSION, @@ -1139,10 +1119,7 @@ where async fn handle_transport_timeout(&mut self, offending_pdu: &[u8]) -> Result<()> { let version = self.version.unwrap_or(SUPPORTED_MAX_VERSION); let timeout = self.transport_timeout(); - let detail = format!( - "transport stalled for longer than {:?}", - timeout - ); + let detail = format!("transport stalled for longer than {:?}", timeout); warn!( "RTR session transport timeout: version={}, offending_pdu_len={}, timeout={:?}", version, @@ -1177,7 +1154,8 @@ where bytes.resize(total_len, 0); timeout( self.transport_timeout(), - self.stream.read_exact(&mut bytes[HEADER_LEN..HEADER_LEN + payload_len]), + self.stream + .read_exact(&mut bytes[HEADER_LEN..HEADER_LEN + payload_len]), ) .await .map_err(|_| 
io::Error::new(io::ErrorKind::TimedOut, "transport read timed out"))??; diff --git a/src/rtr/state.rs b/src/rtr/state.rs index e8f909f..0eb18c8 100644 --- a/src/rtr/state.rs +++ b/src/rtr/state.rs @@ -12,7 +12,7 @@ impl State { pub fn session_ids(self) -> SessionIds { self.session_ids } - + pub fn serial(self) -> u32 { self.serial } diff --git a/src/rtr/store.rs b/src/rtr/store.rs index 6434de8..fdefb0b 100644 --- a/src/rtr/store.rs +++ b/src/rtr/store.rs @@ -1,6 +1,6 @@ +use anyhow::{Result, anyhow}; use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch}; -use anyhow::{anyhow, Result}; -use serde::{de::DeserializeOwned, Serialize}; +use serde::{Serialize, de::DeserializeOwned}; use std::path::Path; use std::sync::Arc; use tokio::task; @@ -66,7 +66,10 @@ impl RtrStore { /// Common serialize/put. fn put_cf(&self, cf: &str, key: &[u8], value: &T) -> Result<()> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + let cf_handle = self + .db + .cf_handle(cf) + .ok_or_else(|| anyhow!("CF not found"))?; let data = serde_json::to_vec(value)?; self.db.put_cf(cf_handle, key, data)?; Ok(()) @@ -74,7 +77,10 @@ impl RtrStore { /// Common get/deserialize. fn get_cf(&self, cf: &str, key: &[u8]) -> Result> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + let cf_handle = self + .db + .cf_handle(cf) + .ok_or_else(|| anyhow!("CF not found"))?; if let Some(value) = self.db.get_cf(cf_handle, key)? { let obj = serde_json::from_slice(&value)?; Ok(Some(obj)) @@ -85,7 +91,10 @@ impl RtrStore { /// Common delete. fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> { - let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; + let cf_handle = self + .db + .cf_handle(cf) + .ok_or_else(|| anyhow!("CF not found"))?; self.db.delete_cf(cf_handle, key)?; Ok(()) } @@ -137,10 +146,12 @@ impl RtrStore { pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> { debug!( "RTR store persisting delta window metadata: min_serial={}, max_serial={}", - min_serial, - max_serial + min_serial, max_serial ); - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; let mut batch = WriteBatch::default(); batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); @@ -150,7 +161,10 @@ impl RtrStore { pub fn clear_delta_window(&self) -> Result<()> { debug!("RTR store clearing delta window metadata"); - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; let mut batch = WriteBatch::default(); batch.delete_cf(meta_cf, META_DELTA_MIN); batch.delete_cf(meta_cf, META_DELTA_MAX); @@ -166,8 +180,7 @@ impl RtrStore { (Some(min), Some(max)) => { debug!( "RTR store loaded delta window metadata: min_serial={}, max_serial={}", - min, - max + min, max ); Ok(Some((min, max))) } @@ -189,7 +202,10 @@ impl RtrStore { // =============================== pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> { - let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let cf_handle = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let mut batch = WriteBatch::default(); let data = 
serde_json::to_vec(snapshot)?; batch.put_cf(cf_handle, b"current", data); @@ -206,8 +222,14 @@ impl RtrStore { } pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> { - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let snapshot_cf = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; let mut batch = WriteBatch::default(); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); @@ -234,8 +256,14 @@ impl RtrStore { serial: u32, ) -> Result<()> { let mut batch = WriteBatch::default(); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let snapshot_cf = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); @@ -266,15 +294,28 @@ impl RtrStore { snapshot.router_keys().len(), snapshot.aspas().len() ); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; - let delta_cf = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?; + let snapshot_cf = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; + let delta_cf = self + .db + .cf_handle(CF_DELTA) + .ok_or_else(|| anyhow!("CF_DELTA not found"))?; let mut batch = WriteBatch::default(); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); - batch.put_cf(meta_cf, META_AVAILABILITY, serde_json::to_vec(&availability)?); + batch.put_cf( + meta_cf, + META_AVAILABILITY, + serde_json::to_vec(&availability)?, + ); if let Some(delta) = delta { debug!( @@ -283,7 +324,11 @@ impl RtrStore { delta.announced().len(), delta.withdrawn().len() ); - batch.put_cf(delta_cf, delta_key(delta.serial()), serde_json::to_vec(delta)?); + batch.put_cf( + delta_cf, + delta_key(delta.serial()), + serde_json::to_vec(delta)?, + ); } if clear_delta_window { @@ -318,8 +363,7 @@ impl RtrStore { } else { debug!( "RTR store found no stale delta records outside window [{}, {}]", - min_serial, - max_serial + min_serial, max_serial ); } for key in stale_keys { @@ -334,8 +378,14 @@ impl RtrStore { pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> { let mut batch = WriteBatch::default(); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let snapshot_cf = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; 
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); self.db.write(batch)?; @@ -352,8 +402,14 @@ impl RtrStore { task::spawn_blocking(move || { let mut batch = WriteBatch::default(); - let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; - let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; + let snapshot_cf = self + .db + .cf_handle(CF_SNAPSHOT) + .ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; + let meta_cf = self + .db + .cf_handle(CF_META) + .ok_or_else(|| anyhow!("CF_META not found"))?; batch.put_cf(snapshot_cf, b"current", snapshot_bytes); batch.put_cf(meta_cf, META_SERIAL, serial_bytes); self.db.write(batch)?; @@ -370,7 +426,9 @@ impl RtrStore { match (snapshot, state) { (Some(snap), Some(state)) => Ok(Some((snap, state))), (None, None) => Ok(None), - _ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")), + _ => Err(anyhow!( + "Inconsistent DB state: snapshot and state mismatch" + )), } } @@ -380,7 +438,9 @@ impl RtrStore { match (snapshot, serial) { (Some(snap), Some(serial)) => Ok(Some((snap, serial))), (None, None) => Ok(None), - _ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")), + _ => Err(anyhow!( + "Inconsistent DB state: snapshot and serial mismatch" + )), } } @@ -413,8 +473,8 @@ impl RtrStore { for item in iter { let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; - let parsed = delta_key_serial(key.as_ref()) - .ok_or_else(|| anyhow!("Invalid delta key"))?; + let parsed = + delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?; if parsed <= serial { continue; @@ -430,8 +490,7 @@ impl RtrStore { pub fn load_delta_window(&self, min_serial: u32, max_serial: u32) -> Result> { info!( "RTR store loading persisted delta window: min_serial={}, max_serial={}", - min_serial, - max_serial + min_serial, max_serial ); let cf_handle = self .db @@ -442,8 +501,8 @@ impl RtrStore { for item in iter { let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; - let parsed = delta_key_serial(key.as_ref()) - .ok_or_else(|| anyhow!("Invalid delta key"))?; + let parsed = + delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?; // Restore by the persisted window bounds instead of load_deltas_since(). 
// The latter follows lexicographic key order and is not safe across serial @@ -493,7 +552,11 @@ impl RtrStore { Ok(keys) } - fn list_delta_keys_outside_window(&self, min_serial: u32, max_serial: u32) -> Result>> { + fn list_delta_keys_outside_window( + &self, + min_serial: u32, + max_serial: u32, + ) -> Result>> { let cf_handle = self .db .cf_handle(CF_DELTA) @@ -503,8 +566,8 @@ impl RtrStore { for item in iter { let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; - let serial = delta_key_serial(key.as_ref()) - .ok_or_else(|| anyhow!("Invalid delta key"))?; + let serial = + delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?; if !serial_in_window(serial, min_serial, max_serial) { keys.push(key.to_vec()); } @@ -522,8 +585,7 @@ fn validate_delta_window(deltas: &[Delta], min_serial: u32, max_serial: u32) -> if deltas.is_empty() { warn!( "RTR store delta window validation failed: no persisted deltas for window [{}, {}]", - min_serial, - max_serial + min_serial, max_serial ); return Err(anyhow!( "delta window [{}, {}] has no persisted deltas", diff --git a/src/slurm/file.rs b/src/slurm/file.rs new file mode 100644 index 0000000..335a01a --- /dev/null +++ b/src/slurm/file.rs @@ -0,0 +1,251 @@ +use std::collections::BTreeSet; +use std::io; + +use crate::rtr::payload::Payload; +use crate::slurm::policy::{LocallyAddedAssertions, ValidationOutputFilters, prefix_encompasses}; + +#[derive(Debug, thiserror::Error)] +pub enum SlurmError { + #[error("failed to parse SLURM JSON: {0}")] + Parse(#[from] serde_json::Error), + + #[error("invalid SLURM file: {0}")] + Invalid(String), + + #[error("I/O error while reading SLURM file: {0}")] + Io(#[from] io::Error), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SlurmVersion { + V1, + V2, +} + +impl SlurmVersion { + pub const V1_U32: u32 = 1; + pub const V2_U32: u32 = 2; + + pub fn as_u32(self) -> u32 { + match self { + Self::V1 => Self::V1_U32, + Self::V2 => Self::V2_U32, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SlurmFile { + version: SlurmVersion, + validation_output_filters: ValidationOutputFilters, + locally_added_assertions: LocallyAddedAssertions, +} + +impl SlurmFile { + pub fn new( + version: SlurmVersion, + validation_output_filters: ValidationOutputFilters, + locally_added_assertions: LocallyAddedAssertions, + ) -> Result { + let slurm = Self { + version, + validation_output_filters, + locally_added_assertions, + }; + slurm.validate()?; + Ok(slurm) + } + + pub fn version(&self) -> SlurmVersion { + self.version + } + + pub fn validation_output_filters(&self) -> &ValidationOutputFilters { + &self.validation_output_filters + } + + pub fn locally_added_assertions(&self) -> &LocallyAddedAssertions { + &self.locally_added_assertions + } + + pub fn apply(&self, payloads: &[Payload]) -> Vec { + let mut seen = BTreeSet::new(); + let mut result = Vec::new(); + + for payload in payloads { + if self.validation_output_filters.matches(payload) { + continue; + } + + if seen.insert(payload.clone()) { + result.push(payload.clone()); + } + } + + for assertion in self.locally_added_assertions.to_payloads() { + if seen.insert(assertion.clone()) { + result.push(assertion); + } + } + + result + } + + pub fn merge_named(files: Vec<(String, SlurmFile)>) -> Result { + if files.is_empty() { + return Err(SlurmError::Invalid( + "SLURM directory does not contain any .slurm files".to_string(), + )); + } + + validate_cross_file_conflicts(&files)?; + + let mut version = SlurmVersion::V1; + let 
mut merged_filters = ValidationOutputFilters { + prefix_filters: Vec::new(), + bgpsec_filters: Vec::new(), + aspa_filters: Vec::new(), + }; + let mut merged_assertions = LocallyAddedAssertions { + prefix_assertions: Vec::new(), + bgpsec_assertions: Vec::new(), + aspa_assertions: Vec::new(), + }; + + for (_, file) in files { + if file.version == SlurmVersion::V2 { + version = SlurmVersion::V2; + } + + merged_filters.extend_from(file.validation_output_filters()); + merged_assertions.extend_from(file.locally_added_assertions()); + } + + Self::new(version, merged_filters, merged_assertions) + } + + fn validate(&self) -> Result<(), SlurmError> { + self.validation_output_filters.validate(self.version)?; + self.locally_added_assertions.validate(self.version)?; + Ok(()) + } +} + +fn validate_cross_file_conflicts(files: &[(String, SlurmFile)]) -> Result<(), SlurmError> { + for i in 0..files.len() { + for j in (i + 1)..files.len() { + let (name_a, file_a) = &files[i]; + let (name_b, file_b) = &files[j]; + + if prefix_spaces_overlap(file_a, file_b) { + return Err(SlurmError::Invalid(format!( + "conflicting SLURM files: '{}' and '{}' have overlapping prefix spaces", + name_a, name_b + ))); + } + + if let Some(asn) = bgpsec_asn_overlap(file_a, file_b) { + return Err(SlurmError::Invalid(format!( + "conflicting SLURM files: '{}' and '{}' both constrain BGPsec ASN {}", + name_a, name_b, asn + ))); + } + + if let Some(customer_asn) = aspa_customer_overlap(file_a, file_b) { + return Err(SlurmError::Invalid(format!( + "conflicting SLURM files: '{}' and '{}' both constrain ASPA customerAsn {}", + name_a, name_b, customer_asn + ))); + } + } + } + + Ok(()) +} + +fn prefix_spaces_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> bool { + let mut lhs_prefixes = lhs + .validation_output_filters() + .prefix_filters + .iter() + .filter_map(|f| f.prefix.as_ref()) + .chain( + lhs.locally_added_assertions() + .prefix_assertions + .iter() + .map(|a| &a.prefix), + ); + + let rhs_prefixes = rhs + .validation_output_filters() + .prefix_filters + .iter() + .filter_map(|f| f.prefix.as_ref()) + .chain( + rhs.locally_added_assertions() + .prefix_assertions + .iter() + .map(|a| &a.prefix), + ) + .collect::>(); + + lhs_prefixes.any(|left| { + rhs_prefixes + .iter() + .any(|right| prefix_encompasses(left, right) || prefix_encompasses(right, left)) + }) +} + +fn bgpsec_asn_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> Option { + let lhs_asns = lhs + .validation_output_filters() + .bgpsec_filters + .iter() + .filter_map(|f| f.asn) + .chain( + lhs.locally_added_assertions() + .bgpsec_assertions + .iter() + .map(|a| a.asn), + ) + .collect::>(); + + rhs.validation_output_filters() + .bgpsec_filters + .iter() + .filter_map(|f| f.asn) + .chain( + rhs.locally_added_assertions() + .bgpsec_assertions + .iter() + .map(|a| a.asn), + ) + .find_map(|asn| lhs_asns.contains(&asn).then_some(asn.into_u32())) +} + +fn aspa_customer_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> Option { + let lhs_customers = lhs + .validation_output_filters() + .aspa_filters + .iter() + .map(|f| f.customer_asn) + .chain( + lhs.locally_added_assertions() + .aspa_assertions + .iter() + .map(|a| a.customer_asn), + ) + .collect::>(); + + rhs.validation_output_filters() + .aspa_filters + .iter() + .map(|f| f.customer_asn) + .chain( + rhs.locally_added_assertions() + .aspa_assertions + .iter() + .map(|a| a.customer_asn), + ) + .find_map(|asn| lhs_customers.contains(&asn).then_some(asn.into_u32())) +} diff --git a/src/slurm/mod.rs b/src/slurm/mod.rs index 0ade6c8..481cd21 100644 
--- a/src/slurm/mod.rs +++ b/src/slurm/mod.rs @@ -1 +1,3 @@ -mod slurm; \ No newline at end of file +pub mod file; +pub mod policy; +mod serde; diff --git a/src/slurm/policy.rs b/src/slurm/policy.rs new file mode 100644 index 0000000..e721864 --- /dev/null +++ b/src/slurm/policy.rs @@ -0,0 +1,409 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; + +use crate::data_model::resources::as_resources::Asn; +use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski}; +use crate::slurm::file::{SlurmError, SlurmVersion}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidationOutputFilters { + pub prefix_filters: Vec, + pub bgpsec_filters: Vec, + pub aspa_filters: Vec, +} + +impl ValidationOutputFilters { + pub(crate) fn extend_from(&mut self, other: &Self) { + self.prefix_filters + .extend(other.prefix_filters.iter().cloned()); + self.bgpsec_filters + .extend(other.bgpsec_filters.iter().cloned()); + self.aspa_filters.extend(other.aspa_filters.iter().cloned()); + } + + pub(crate) fn validate(&self, version: SlurmVersion) -> Result<(), SlurmError> { + if version == SlurmVersion::V1 && !self.aspa_filters.is_empty() { + return Err(SlurmError::Invalid( + "slurmVersion 1 must not contain aspaFilters".to_string(), + )); + } + + for filter in &self.prefix_filters { + filter.validate()?; + } + for filter in &self.bgpsec_filters { + filter.validate()?; + } + for filter in &self.aspa_filters { + filter.validate()?; + } + Ok(()) + } + + pub(crate) fn matches(&self, payload: &Payload) -> bool { + match payload { + Payload::RouteOrigin(route_origin) => self + .prefix_filters + .iter() + .any(|filter| filter.matches(route_origin)), + Payload::RouterKey(router_key) => self + .bgpsec_filters + .iter() + .any(|filter| filter.matches(router_key)), + Payload::Aspa(aspa) => self.aspa_filters.iter().any(|filter| filter.matches(aspa)), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocallyAddedAssertions { + pub prefix_assertions: Vec, + pub bgpsec_assertions: Vec, + pub aspa_assertions: Vec, +} + +impl LocallyAddedAssertions { + pub(crate) fn extend_from(&mut self, other: &Self) { + self.prefix_assertions + .extend(other.prefix_assertions.iter().cloned()); + self.bgpsec_assertions + .extend(other.bgpsec_assertions.iter().cloned()); + self.aspa_assertions + .extend(other.aspa_assertions.iter().cloned()); + } + + pub(crate) fn validate(&self, version: SlurmVersion) -> Result<(), SlurmError> { + if version == SlurmVersion::V1 && !self.aspa_assertions.is_empty() { + return Err(SlurmError::Invalid( + "slurmVersion 1 must not contain aspaAssertions".to_string(), + )); + } + + for assertion in &self.prefix_assertions { + assertion.validate()?; + } + for assertion in &self.bgpsec_assertions { + assertion.validate()?; + } + for assertion in &self.aspa_assertions { + assertion.validate()?; + } + Ok(()) + } + + pub(crate) fn to_payloads(&self) -> Vec { + let mut payloads = Vec::with_capacity( + self.prefix_assertions.len() + + self.bgpsec_assertions.len() + + self.aspa_assertions.len(), + ); + + payloads.extend( + self.prefix_assertions + .iter() + .cloned() + .map(|assertion| Payload::RouteOrigin(assertion.into_route_origin())), + ); + payloads.extend( + self.bgpsec_assertions + .iter() + .cloned() + .map(|assertion| Payload::RouterKey(assertion.into_router_key())), + ); + payloads.extend( + self.aspa_assertions + .iter() + .cloned() + .map(|assertion| Payload::Aspa(assertion.into_aspa())), + 
); + + payloads + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrefixFilter { + pub prefix: Option, + pub asn: Option, + pub comment: Option, +} + +impl PrefixFilter { + fn validate(&self) -> Result<(), SlurmError> { + if self.prefix.is_none() && self.asn.is_none() { + return Err(SlurmError::Invalid( + "prefixFilter must contain at least one of prefix or asn".to_string(), + )); + } + Ok(()) + } + + fn matches(&self, route_origin: &RouteOrigin) -> bool { + let prefix_match = self + .prefix + .is_none_or(|filter_prefix| prefix_encompasses(&filter_prefix, route_origin.prefix())); + let asn_match = self.asn.is_none_or(|asn| asn == route_origin.asn()); + prefix_match && asn_match + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BgpsecFilter { + pub asn: Option, + pub ski: Option, + pub comment: Option, +} + +impl BgpsecFilter { + fn validate(&self) -> Result<(), SlurmError> { + if self.asn.is_none() && self.ski.is_none() { + return Err(SlurmError::Invalid( + "bgpsecFilter must contain at least one of asn or SKI".to_string(), + )); + } + Ok(()) + } + + fn matches(&self, router_key: &RouterKey) -> bool { + let asn_match = self.asn.is_none_or(|asn| asn == router_key.asn()); + let ski_match = self.ski.is_none_or(|ski| ski == router_key.ski()); + asn_match && ski_match + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AspaFilter { + pub customer_asn: Asn, + pub comment: Option, +} + +impl AspaFilter { + fn validate(&self) -> Result<(), SlurmError> { + if self.customer_asn.into_u32() == 0 { + return Err(SlurmError::Invalid( + "aspaFilter customerAsn must not be AS0".to_string(), + )); + } + Ok(()) + } + + fn matches(&self, aspa: &Aspa) -> bool { + self.customer_asn == aspa.customer_asn() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrefixAssertion { + pub prefix: IPAddressPrefix, + pub asn: Asn, + pub max_prefix_length: Option, + pub comment: Option, +} + +impl PrefixAssertion { + fn validate(&self) -> Result<(), SlurmError> { + if self.asn.into_u32() == 0 { + return Err(SlurmError::Invalid( + "prefixAssertion asn must not be AS0".to_string(), + )); + } + + if let Some(max_prefix_length) = self.max_prefix_length { + let address_bits = prefix_address_bits(&self.prefix); + if max_prefix_length < self.prefix.prefix_length() { + return Err(SlurmError::Invalid(format!( + "prefixAssertion maxPrefixLength {} must be >= prefix length {}", + max_prefix_length, + self.prefix.prefix_length() + ))); + } + if max_prefix_length > address_bits { + return Err(SlurmError::Invalid(format!( + "prefixAssertion maxPrefixLength {} exceeds address size {}", + max_prefix_length, address_bits + ))); + } + } + + Ok(()) + } + + fn into_route_origin(self) -> RouteOrigin { + RouteOrigin::new( + self.prefix, + self.max_prefix_length + .unwrap_or(self.prefix.prefix_length()), + self.asn, + ) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BgpsecAssertion { + pub asn: Asn, + pub ski: Ski, + pub router_public_key: Vec, + pub comment: Option, +} + +impl BgpsecAssertion { + fn validate(&self) -> Result<(), SlurmError> { + RouterKey::new(self.ski, self.asn, self.router_public_key.clone()) + .validate() + .map_err(|err| SlurmError::Invalid(err.to_string())) + } + + fn into_router_key(self) -> RouterKey { + RouterKey::new(self.ski, self.asn, self.router_public_key) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AspaAssertion { + pub customer_asn: Asn, + pub provider_asns: Vec, + pub comment: Option, +} + +impl AspaAssertion { + fn validate(&self) -> 
Result<(), SlurmError> { + let providers = self + .provider_asns + .iter() + .map(|asn| asn.into_u32()) + .collect::>(); + + if providers.windows(2).any(|window| window[0] >= window[1]) { + return Err(SlurmError::Invalid( + "aspaAssertion providerAsns must be strictly increasing".to_string(), + )); + } + + let aspa = Aspa::new(self.customer_asn, self.provider_asns.clone()); + aspa.validate_announcement() + .map_err(|err| SlurmError::Invalid(err.to_string()))?; + + Ok(()) + } + + fn into_aspa(self) -> Aspa { + Aspa::new(self.customer_asn, self.provider_asns) + } +} + +pub(crate) fn parse_ip_prefix(input: &str) -> Result { + let (addr, prefix_length) = input + .split_once('/') + .ok_or_else(|| SlurmError::Invalid(format!("invalid prefix '{}'", input)))?; + + let address = IpAddr::from_str(addr.trim()) + .map_err(|err| SlurmError::Invalid(format!("invalid IP address '{}': {}", addr, err)))?; + let prefix_length = prefix_length.trim().parse::().map_err(|err| { + SlurmError::Invalid(format!( + "invalid prefix length '{}': {}", + prefix_length, err + )) + })?; + + match address { + IpAddr::V4(addr) => { + if prefix_length > 32 { + return Err(SlurmError::Invalid(format!( + "IPv4 prefix length {} exceeds 32", + prefix_length + ))); + } + if !is_canonical_v4(addr, prefix_length) { + return Err(SlurmError::Invalid(format!( + "IPv4 prefix '{}' is not canonical", + input + ))); + } + Ok(IPAddressPrefix::new( + IPAddress::from_ipv4(addr), + prefix_length, + )) + } + IpAddr::V6(addr) => { + if prefix_length > 128 { + return Err(SlurmError::Invalid(format!( + "IPv6 prefix length {} exceeds 128", + prefix_length + ))); + } + if !is_canonical_v6(addr, prefix_length) { + return Err(SlurmError::Invalid(format!( + "IPv6 prefix '{}' is not canonical", + input + ))); + } + Ok(IPAddressPrefix::new( + IPAddress::from_ipv6(addr), + prefix_length, + )) + } + } +} + +pub(crate) fn prefix_address_bits(prefix: &IPAddressPrefix) -> u8 { + match prefix.address() { + IPAddress::V4(_) => 32, + IPAddress::V6(_) => 128, + } +} + +pub(crate) fn prefix_encompasses(filter: &IPAddressPrefix, other: &IPAddressPrefix) -> bool { + if filter.afi() != other.afi() { + return false; + } + if filter.prefix_length() > other.prefix_length() { + return false; + } + + match (filter.address(), other.address()) { + (IPAddress::V4(lhs), IPAddress::V4(rhs)) => { + prefix_match_v4(lhs, rhs, filter.prefix_length()) + } + (IPAddress::V6(lhs), IPAddress::V6(rhs)) => { + prefix_match_v6(lhs, rhs, filter.prefix_length()) + } + _ => false, + } +} + +fn prefix_match_v4(lhs: Ipv4Addr, rhs: Ipv4Addr, prefix_length: u8) -> bool { + let mask = if prefix_length == 0 { + 0 + } else { + u32::MAX << (32 - prefix_length) + }; + (u32::from(lhs) & mask) == (u32::from(rhs) & mask) +} + +fn prefix_match_v6(lhs: Ipv6Addr, rhs: Ipv6Addr, prefix_length: u8) -> bool { + let mask = if prefix_length == 0 { + 0 + } else { + u128::MAX << (128 - prefix_length) + }; + (u128::from(lhs) & mask) == (u128::from(rhs) & mask) +} + +fn is_canonical_v4(addr: Ipv4Addr, prefix_length: u8) -> bool { + let mask = if prefix_length == 0 { + 0 + } else { + u32::MAX << (32 - prefix_length) + }; + (u32::from(addr) & !mask) == 0 +} + +fn is_canonical_v6(addr: Ipv6Addr, prefix_length: u8) -> bool { + let mask = if prefix_length == 0 { + 0 + } else { + u128::MAX << (128 - prefix_length) + }; + (u128::from(addr) & !mask) == 0 +} diff --git a/src/slurm/serde.rs b/src/slurm/serde.rs new file mode 100644 index 0000000..2e2001f --- /dev/null +++ b/src/slurm/serde.rs @@ -0,0 +1,313 @@ +use std::io; 
+
+use base64::Engine;
+use base64::engine::general_purpose::STANDARD_NO_PAD;
+use serde::Deserialize;
+
+use crate::data_model::resources::as_resources::Asn;
+use crate::rtr::payload::Ski;
+use crate::slurm::file::{SlurmError, SlurmFile, SlurmVersion};
+use crate::slurm::policy::{
+    AspaAssertion, AspaFilter, BgpsecAssertion, BgpsecFilter, LocallyAddedAssertions,
+    PrefixAssertion, PrefixFilter, ValidationOutputFilters, parse_ip_prefix,
+};
+
+impl SlurmFile {
+    pub fn from_slice(input: &[u8]) -> Result<Self, SlurmError> {
+        let version = serde_json::from_slice::<SlurmVersionMarker>(input)?.slurm_version;
+        match version {
+            SlurmVersion::V1_U32 => {
+                let raw = serde_json::from_slice::<RawSlurmFileV1>(input)?;
+                Self::from_raw_v1(raw)
+            }
+            SlurmVersion::V2_U32 => {
+                let raw = serde_json::from_slice::<RawSlurmFileV2>(input)?;
+                Self::from_raw_v2(raw)
+            }
+            other => Err(SlurmError::Invalid(format!(
+                "unsupported slurmVersion {}, expected 1 or 2",
+                other
+            ))),
+        }
+    }
+
+    pub fn from_reader(mut reader: impl io::Read) -> Result<Self, SlurmError> {
+        let mut bytes = Vec::new();
+        reader.read_to_end(&mut bytes)?;
+        Self::from_slice(&bytes)
+    }
+
+    fn from_raw_v1(raw: RawSlurmFileV1) -> Result<Self, SlurmError> {
+        Self::new(
+            SlurmVersion::V1,
+            ValidationOutputFilters {
+                prefix_filters: raw.validation_output_filters.prefix_filters,
+                bgpsec_filters: raw.validation_output_filters.bgpsec_filters,
+                aspa_filters: Vec::new(),
+            },
+            LocallyAddedAssertions {
+                prefix_assertions: raw.locally_added_assertions.prefix_assertions,
+                bgpsec_assertions: raw.locally_added_assertions.bgpsec_assertions,
+                aspa_assertions: Vec::new(),
+            },
+        )
+    }
+
+    fn from_raw_v2(raw: RawSlurmFileV2) -> Result<Self, SlurmError> {
+        Self::new(
+            SlurmVersion::V2,
+            ValidationOutputFilters {
+                prefix_filters: raw.validation_output_filters.prefix_filters,
+                bgpsec_filters: raw.validation_output_filters.bgpsec_filters,
+                aspa_filters: raw.validation_output_filters.aspa_filters,
+            },
+            LocallyAddedAssertions {
+                prefix_assertions: raw.locally_added_assertions.prefix_assertions,
+                bgpsec_assertions: raw.locally_added_assertions.bgpsec_assertions,
+                aspa_assertions: raw.locally_added_assertions.aspa_assertions,
+            },
+        )
+    }
+}
+
+#[derive(Deserialize)]
+struct SlurmVersionMarker {
+    #[serde(rename = "slurmVersion")]
+    slurm_version: u32,
+}
+
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+struct RawSlurmFileV1 {
+    #[serde(rename = "slurmVersion")]
+    _slurm_version: u32,
+    #[serde(rename = "validationOutputFilters")]
+    validation_output_filters: RawValidationOutputFiltersV1,
+    #[serde(rename = "locallyAddedAssertions")]
+    locally_added_assertions: RawLocallyAddedAssertionsV1,
+}
+
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+struct RawSlurmFileV2 {
+    #[serde(rename = "slurmVersion")]
+    _slurm_version: u32,
+    #[serde(rename = "validationOutputFilters")]
+    validation_output_filters: RawValidationOutputFiltersV2,
+    #[serde(rename = "locallyAddedAssertions")]
+    locally_added_assertions: RawLocallyAddedAssertionsV2,
+}
+
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+struct RawValidationOutputFiltersV1 {
+    #[serde(rename = "prefixFilters")]
+    prefix_filters: Vec<PrefixFilter>,
+    #[serde(rename = "bgpsecFilters")]
+    bgpsec_filters: Vec<BgpsecFilter>,
+}
+
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+struct RawValidationOutputFiltersV2 {
+    #[serde(rename = "prefixFilters")]
+    prefix_filters: Vec<PrefixFilter>,
+    #[serde(rename = "bgpsecFilters")]
+    bgpsec_filters: Vec<BgpsecFilter>,
+    #[serde(rename = "aspaFilters")]
+    aspa_filters: Vec<AspaFilter>,
+}
+
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+struct RawLocallyAddedAssertionsV1 {
+    #[serde(rename = "prefixAssertions")]
"prefixAssertions")] + prefix_assertions: Vec, + #[serde(rename = "bgpsecAssertions")] + bgpsec_assertions: Vec, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawLocallyAddedAssertionsV2 { + #[serde(rename = "prefixAssertions")] + prefix_assertions: Vec, + #[serde(rename = "bgpsecAssertions")] + bgpsec_assertions: Vec, + #[serde(rename = "aspaAssertions")] + aspa_assertions: Vec, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawPrefixFilter { + prefix: Option, + asn: Option, + comment: Option, +} + +impl<'de> Deserialize<'de> for PrefixFilter { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RawPrefixFilter::deserialize(deserializer)?; + Ok(Self { + prefix: raw + .prefix + .map(|prefix| parse_ip_prefix(&prefix)) + .transpose() + .map_err(serde::de::Error::custom)?, + asn: raw.asn.map(Asn::from), + comment: raw.comment, + }) + } +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawBgpsecFilter { + asn: Option, + #[serde(rename = "SKI")] + ski: Option, + comment: Option, +} + +impl<'de> Deserialize<'de> for BgpsecFilter { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RawBgpsecFilter::deserialize(deserializer)?; + Ok(Self { + asn: raw.asn.map(Asn::from), + ski: raw + .ski + .map(|ski| decode_ski(&ski)) + .transpose() + .map_err(serde::de::Error::custom)?, + comment: raw.comment, + }) + } +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawAspaFilter { + #[serde(rename = "customerAsn")] + customer_asn: u32, + comment: Option, +} + +impl<'de> Deserialize<'de> for AspaFilter { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RawAspaFilter::deserialize(deserializer)?; + Ok(Self { + customer_asn: Asn::from(raw.customer_asn), + comment: raw.comment, + }) + } +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawPrefixAssertion { + prefix: String, + asn: u32, + #[serde(rename = "maxPrefixLength")] + max_prefix_length: Option, + comment: Option, +} + +impl<'de> Deserialize<'de> for PrefixAssertion { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RawPrefixAssertion::deserialize(deserializer)?; + Ok(Self { + prefix: parse_ip_prefix(&raw.prefix).map_err(serde::de::Error::custom)?, + asn: Asn::from(raw.asn), + max_prefix_length: raw.max_prefix_length, + comment: raw.comment, + }) + } +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawBgpsecAssertion { + asn: u32, + #[serde(rename = "SKI")] + ski: String, + #[serde(rename = "routerPublicKey")] + router_public_key: String, + comment: Option, +} + +impl<'de> Deserialize<'de> for BgpsecAssertion { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RawBgpsecAssertion::deserialize(deserializer)?; + Ok(Self { + asn: Asn::from(raw.asn), + ski: decode_ski(&raw.ski).map_err(serde::de::Error::custom)?, + router_public_key: decode_router_public_key(&raw.router_public_key) + .map_err(serde::de::Error::custom)?, + comment: raw.comment, + }) + } +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +struct RawAspaAssertion { + #[serde(rename = "customerAsn")] + customer_asn: u32, + #[serde(rename = "providerAsns")] + provider_asns: Vec, + comment: Option, +} + +impl<'de> Deserialize<'de> for AspaAssertion { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let 
+        Ok(Self {
+            customer_asn: Asn::from(raw.customer_asn),
+            provider_asns: raw.provider_asns.into_iter().map(Asn::from).collect(),
+            comment: raw.comment,
+        })
+    }
+}
+
+fn decode_ski(input: &str) -> Result<Ski, SlurmError> {
+    let bytes = hex::decode(input)
+        .map_err(|err| SlurmError::Invalid(format!("invalid SKI '{}': {}", input, err)))?;
+    if bytes.len() != 20 {
+        return Err(SlurmError::Invalid(format!(
+            "SKI must be exactly 20 bytes, got {}",
+            bytes.len()
+        )));
+    }
+
+    let mut ski = [0u8; 20];
+    ski.copy_from_slice(&bytes);
+    Ok(Ski::from_bytes(ski))
+}
+
+fn decode_router_public_key(input: &str) -> Result<Vec<u8>, SlurmError> {
+    STANDARD_NO_PAD.decode(input).map_err(|err| {
+        SlurmError::Invalid(format!(
+            "invalid routerPublicKey base64 '{}': {}",
+            input, err
+        ))
+    })
+}
diff --git a/src/slurm/slurm.rs b/src/slurm/slurm.rs
deleted file mode 100644
index 7b4b268..0000000
--- a/src/slurm/slurm.rs
+++ /dev/null
@@ -1,80 +0,0 @@
-use std::io;
-use crate::data_model::resources::as_resources::Asn;
-
-
-#[derive(Debug, thiserror::Error)]
-pub enum SlurmError {
-    #[error("Read slurm from reader error")]
-    SlurmFromReader(),
-}
-
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct SlurmFile {
-    pub version: u32,
-    pub validation_output_filters: ValidationOutputFilters,
-    pub locally_added_assertions: LocallyAddedAssertions,
-}
-
-impl SlurmFile {
-    pub fn new(filters: ValidationOutputFilters,
-               assertions: LocallyAddedAssertions,) -> Self {
-        let version = 1;
-        SlurmFile {
-            version,
-            validation_output_filters: filters,
-            locally_added_assertions: assertions,
-        }
-    }
-
-    // pub fn from_reader(reader: impl io::Read)-> Result {
-    //
-    // }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct ValidationOutputFilters {
-    pub prefix_filters: Vec<PrefixFilter>,
-    pub bgpset_filters: Vec<BgpsecFilter>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct Comment(String);
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct PrefixFilter {
-    pub prefix: String,
-    pub asn: Asn,
-    pub comment: Option<Comment>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct BgpsecFilter {
-    pub asn: Asn,
-    pub ski: u8,
-    pub comment: Option<Comment>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct LocallyAddedAssertions {
-    pub prefix_assertions: Vec<PrefixAssertion>,
-    pub bgpsec_assertions: Vec<BgpsecAssertion>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct PrefixAssertion {
-    pub prefix: String,
-    pub asn: Asn,
-    pub max_prefix_length: u8,
-    pub comment: Option<Comment>,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct BgpsecAssertion {
-    pub asn: Asn,
-    pub ski: u8,
-    pub router_public_key: u8,
-    pub comment: Option<Comment>,
-}
-
-
diff --git a/src/rtr/ccr.rs b/src/source/ccr.rs
similarity index 94%
rename from src/rtr/ccr.rs
rename to src/source/ccr.rs
index 1cc30fb..7ec33bc 100644
--- a/src/rtr/ccr.rs
+++ b/src/source/ccr.rs
@@ -2,7 +2,7 @@ use std::fs;
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
 use std::path::{Path, PathBuf};
 
-use anyhow::{anyhow, Context, Result};
+use anyhow::{Context, Result, anyhow};
 use der_parser::ber::{BerObject, BerObjectContent};
 use der_parser::der::parse_der;
 
@@ -29,8 +29,8 @@ pub struct CcrPayloadConversion {
 pub fn load_ccr_snapshot_from_file(path: impl AsRef<Path>) -> Result<ParsedCcrSnapshot> {
     let path = path.as_ref();
 
-    let bytes = fs::read(path)
-        .with_context(|| format!("failed to read CCR file: {}", path.display()))?;
+    let bytes =
+        fs::read(path).with_context(|| format!("failed to read CCR file: {}", path.display()))?;
 
     parse_ccr_bytes(&bytes).with_context(|| format!("failed to parse CCR file: {}",
path.display())) } @@ -64,9 +64,10 @@ pub fn find_latest_ccr_file(dir: impl AsRef) -> Result { continue; } - if latest.as_ref().is_none_or(|current| { - file_name_key(&path) > file_name_key(current) - }) { + if latest + .as_ref() + .is_none_or(|current| file_name_key(&path) > file_name_key(current)) + { latest = Some(path); } } @@ -250,10 +251,7 @@ fn parse_vaps(field: &BerObject<'_>) -> Result> { Ok(vaps) } -fn parse_roa_address( - address_family: &[u8], - items: &[BerObject<'_>], -) -> Result<(IpAddr, u8, u8)> { +fn parse_roa_address(address_family: &[u8], items: &[BerObject<'_>]) -> Result<(IpAddr, u8, u8)> { let address = items .first() .ok_or_else(|| anyhow!("ROAIPAddress missing address field"))?; @@ -275,8 +273,7 @@ fn parse_roa_address( let max_len = match items.get(1) { Some(value) => { let max_len = as_u32(value, "ROAIPAddress.maxLength")?; - u8::try_from(max_len) - .map_err(|_| anyhow!("maxLength {max_len} does not fit in u8"))? + u8::try_from(max_len).map_err(|_| anyhow!("maxLength {max_len} does not fit in u8"))? } None => prefix_len, }; @@ -328,10 +325,7 @@ fn decode_context_wrapped_sequence<'a>(obj: &'a BerObject<'a>) -> Result, + pub strict_ccr_validation: bool, +} + +pub fn load_payloads_from_latest_sources(config: &PayloadLoadConfig) -> Result> { + let payloads = load_payloads_from_latest_ccr(&config.ccr_dir, config.strict_ccr_validation)?; + + match config.slurm_dir.as_deref() { + Some(dir) => apply_slurm_to_payloads_from_dir(dir, payloads), + None => Ok(payloads), + } +} + +fn load_payloads_from_latest_ccr( + ccr_dir: &str, + strict_ccr_validation: bool, +) -> Result> { + let latest = find_latest_ccr_file(ccr_dir)?; + let snapshot = load_ccr_snapshot_from_file(&latest)?; + let vrp_count = snapshot.vrps.len(); + let vap_count = snapshot.vaps.len(); + let produced_at = snapshot.produced_at.clone(); + let conversion = load_ccr_payloads_from_file_with_options(&latest, strict_ccr_validation)?; + let payloads = conversion.payloads; + + if !conversion.invalid_vrps.is_empty() { + warn!( + "CCR load skipped invalid VRPs: file={}, skipped={}, samples={:?}", + latest.display(), + conversion.invalid_vrps.len(), + sample_messages(&conversion.invalid_vrps) + ); + } + + if !conversion.invalid_vaps.is_empty() { + warn!( + "CCR load skipped invalid VAPs/ASPAs: file={}, skipped={}, samples={:?}", + latest.display(), + conversion.invalid_vaps.len(), + sample_messages(&conversion.invalid_vaps) + ); + } + + info!( + "loaded latest CCR snapshot: file={}, produced_at={:?}, vrp_count={}, vap_count={}, payload_count={}, strict_ccr_validation={}", + latest.display(), + produced_at, + vrp_count, + vap_count, + payloads.len(), + strict_ccr_validation + ); + + Ok(payloads) +} + +fn apply_slurm_to_payloads_from_dir( + slurm_dir: &str, + payloads: Vec, +) -> Result> { + let files = read_slurm_files(slurm_dir)?; + let file_count = files.len(); + let file_names = files + .iter() + .map(|(name, _)| name.clone()) + .collect::>(); + let slurm = SlurmFile::merge_named(files) + .map_err(|err| anyhow!("failed to merge SLURM files from '{}': {}", slurm_dir, err))?; + + let input_count = payloads.len(); + let filtered = slurm.apply(&payloads); + let output_count = filtered.len(); + + info!( + "applied SLURM policy set: slurm_dir={}, file_count={}, files={:?}, merged_slurm_version={}, input_payload_count={}, output_payload_count={}", + slurm_dir, + file_count, + file_names, + slurm.version().as_u32(), + input_count, + output_count + ); + + Ok(filtered) +} + +fn read_slurm_files(slurm_dir: &str) -> Result> { + 
let mut paths = std::fs::read_dir(slurm_dir) + .map_err(|err| anyhow!("failed to read SLURM directory '{}': {}", slurm_dir, err))? + .filter_map(|entry| entry.ok()) + .map(|entry| entry.path()) + .filter(|path| path.is_file()) + .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("slurm")) + .collect::>(); + + paths.sort_by_key(|path| { + path.file_name() + .and_then(|name| name.to_str()) + .map(|name| name.to_ascii_lowercase()) + .unwrap_or_default() + }); + + if paths.is_empty() { + return Err(anyhow!( + "SLURM directory '{}' does not contain .slurm files", + slurm_dir + )); + } + + paths + .into_iter() + .map(|path| { + let name = path.to_string_lossy().to_string(); + let file = std::fs::File::open(&path) + .map_err(|err| anyhow!("failed to open SLURM file '{}': {}", name, err))?; + let slurm = SlurmFile::from_reader(file) + .map_err(|err| anyhow!("failed to parse SLURM file '{}': {}", name, err))?; + Ok((name, slurm)) + }) + .collect() +} + +fn sample_messages(messages: &[String]) -> Vec<&str> { + messages.iter().take(3).map(String::as_str).collect() +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index a7e5207..e34c08f 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1 +1 @@ -pub mod test_helper; \ No newline at end of file +pub mod test_helper; diff --git a/tests/common/test_helper.rs b/tests/common/test_helper.rs index bfb1676..de1467e 100644 --- a/tests/common/test_helper.rs +++ b/tests/common/test_helper.rs @@ -1,7 +1,7 @@ use std::fmt::Write; use std::net::{Ipv4Addr, Ipv6Addr}; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; use rpki::rtr::cache::SerialResult; @@ -14,7 +14,9 @@ pub struct RtrDebugDumper { impl RtrDebugDumper { pub fn new() -> Self { - Self { entries: Vec::new() } + Self { + entries: Vec::new(), + } } pub fn push(&mut self, pdu: u8, body: &T) { @@ -150,15 +152,7 @@ pub fn v6_prefix(addr: Ipv6Addr, prefix_len: u8) -> IPAddressPrefix { } } -pub fn v4_origin( - a: u8, - b: u8, - c: u8, - d: u8, - prefix_len: u8, - max_len: u8, - asn: u32, -) -> RouteOrigin { +pub fn v4_origin(a: u8, b: u8, c: u8, d: u8, prefix_len: u8, max_len: u8, asn: u32) -> RouteOrigin { let prefix = v4_prefix(a, b, c, d, prefix_len); RouteOrigin::new(prefix, max_len, asn.into()) } @@ -238,7 +232,11 @@ pub fn serial_result_to_string(result: &SerialResult) -> String { } pub fn print_serial_result(label: &str, result: &SerialResult) { - println!("\n===== {} =====\n{}\n", label, serial_result_to_string(result)); + println!( + "\n===== {} =====\n{}\n", + label, + serial_result_to_string(result) + ); } pub fn bytes_to_hex(bytes: &[u8]) -> String { @@ -290,12 +288,8 @@ pub fn snapshot_hashes_to_string(snapshot: &rpki::rtr::cache::Snapshot) -> Strin pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -> String { match result { - rpki::rtr::cache::SerialResult::UpToDate => { - " result: UpToDate\n".to_string() - } - rpki::rtr::cache::SerialResult::ResetRequired => { - " result: ResetRequired\n".to_string() - } + rpki::rtr::cache::SerialResult::UpToDate => " result: UpToDate\n".to_string(), + rpki::rtr::cache::SerialResult::ResetRequired => " result: ResetRequired\n".to_string(), rpki::rtr::cache::SerialResult::Delta(delta) => { let mut out = String::new(); let _ = writeln!(&mut out, " result: Delta"); @@ -316,4 +310,4 @@ pub fn indent_block(text: &str, spaces: usize) -> String { let _ = writeln!(&mut out, "{}{}", pad, line); } out -} \ No newline 
at end of file +} diff --git a/tests/test_cache.rs b/tests/test_cache.rs index c82897d..611a7b6 100644 --- a/tests/test_cache.rs +++ b/tests/test_cache.rs @@ -1,4 +1,4 @@ -mod common; +mod common; use std::collections::VecDeque; use std::net::{Ipv4Addr, Ipv6Addr}; @@ -12,8 +12,7 @@ use common::test_helper::{ use rpki::data_model::resources::as_resources::Asn; use rpki::rtr::cache::{ CacheAvailability, Delta, RtrCacheBuilder, SerialResult, SessionIds, Snapshot, - validate_payload_updates_for_rtr, - validate_payloads_for_rtr, + validate_payload_updates_for_rtr, validate_payloads_for_rtr, }; use rpki::rtr::payload::{Aspa, Payload, RouterKey, Ski, Timing}; use rpki::rtr::store::RtrStore; @@ -40,7 +39,11 @@ fn deltas_window_to_string(deltas: &VecDeque>) -> String { out } -fn get_deltas_since_input_to_string(cache_session_id: u16, cache_serial: u32, client_serial: u32) -> String { +fn get_deltas_since_input_to_string( + cache_session_id: u16, + cache_serial: u32, + client_serial: u32, +) -> String { format!( "cache.session_id: {}\ncache.serial: {}\nclient_serial: {}\n", cache_session_id, cache_serial, client_serial @@ -118,7 +121,9 @@ async fn init_keeps_cache_running_when_file_loader_returns_no_data() { let store = RtrStore::open(dir.path()).unwrap(); let cache = rpki::rtr::cache::RtrCache::default() - .init(&store, 16, false, Timing::new(600, 600, 7200), || Ok(vec![])) + .init(&store, 16, false, Timing::new(600, 600, 7200), || { + Ok(vec![]) + }) .unwrap(); assert!(!cache.is_data_available()); @@ -144,12 +149,16 @@ async fn init_restores_wraparound_delta_window_from_store() { ); let d_zero = Delta::new( 0, - vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin( + 198, 51, 100, 0, 24, 24, 64497, + ))], vec![], ); let d_one = Delta::new( 1, - vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![Payload::RouteOrigin(v4_origin( + 203, 0, 113, 0, 24, 24, 64498, + ))], vec![], ); @@ -188,7 +197,9 @@ async fn init_restores_wraparound_delta_window_from_store() { .unwrap(); let cache = rpki::rtr::cache::RtrCache::default() - .init(&store, 16, false, Timing::new(600, 600, 7200), || Ok(Vec::new())) + .init(&store, 16, false, Timing::new(600, 600, 7200), || { + Ok(Vec::new()) + }) .unwrap(); match cache.get_deltas_since(u32::MAX.wrapping_sub(1)) { @@ -205,8 +216,8 @@ async fn update_prunes_delta_window_when_cumulative_delta_size_reaches_snapshot_ let dir = tempfile::tempdir().unwrap(); let store = RtrStore::open(dir.path()).unwrap(); let valid_spki = vec![ - 0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, - 0x01, 0x01, 0x05, 0x00, 0x03, 0x02, 0x00, 0x00, + 0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, + 0x05, 0x00, 0x03, 0x02, 0x00, 0x00, ]; let initial_snapshot = Snapshot::from_payloads(vec![Payload::RouterKey(RouterKey::new( @@ -226,11 +237,14 @@ async fn update_prunes_delta_window_when_cumulative_delta_size_reaches_snapshot_ .build(); cache - .update(vec![Payload::RouterKey(RouterKey::new( - Ski::from_bytes([1u8; 20]), - Asn::from(64496u32), - valid_spki, - ))], &store) + .update( + vec![Payload::RouterKey(RouterKey::new( + Ski::from_bytes([1u8; 20]), + Asn::from(64496u32), + valid_spki, + ))], + &store, + ) .unwrap(); match cache.get_deltas_since(1) { @@ -414,7 +428,10 @@ fn delta_new_sorts_announced_descending_and_withdrawn_ascending() { let w0 = as_v4_route_origin(&delta.withdrawn()[0]); let w1 = as_v4_route_origin(&delta.withdrawn()[1]); - 
assert_eq!(w0.prefix().address.to_ipv4(), Some(Ipv4Addr::new(10, 0, 0, 0))); + assert_eq!( + w0.prefix().address.to_ipv4(), + Some(Ipv4Addr::new(10, 0, 0, 0)) + ); assert_eq!( w1.prefix().address.to_ipv4(), Some(Ipv4Addr::new(203, 0, 113, 0)) @@ -433,7 +450,8 @@ fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() { let result = cache.get_deltas_since(100); - let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100); + let input = + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100); let output = serial_result_detail_to_string(&result); test_report( @@ -458,7 +476,9 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() { )); let d2 = Arc::new(Delta::new( 102, - vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin( + 198, 51, 100, 0, 24, 24, 64497, + ))], vec![], )); @@ -506,12 +526,16 @@ fn get_deltas_since_returns_minimal_merged_delta() { )); let d2 = Arc::new(Delta::new( 102, - vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin( + 198, 51, 100, 0, 24, 24, 64497, + ))], vec![], )); let d3 = Arc::new(Delta::new( 103, - vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![Payload::RouteOrigin(v4_origin( + 203, 0, 113, 0, 24, 24, 64498, + ))], vec![], )); @@ -584,7 +608,8 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() { let result = cache.get_deltas_since(101); - let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101); + let input = + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101); let output = serial_result_detail_to_string(&result); test_report( @@ -610,11 +635,7 @@ fn get_deltas_since_supports_incremental_updates_across_serial_wraparound() { vec![Payload::RouteOrigin(a.clone())], vec![], )); - let d_zero = Arc::new(Delta::new( - 0, - vec![Payload::RouteOrigin(b.clone())], - vec![], - )); + let d_zero = Arc::new(Delta::new(0, vec![Payload::RouteOrigin(b.clone())], vec![])); let mut deltas = VecDeque::new(); deltas.push_back(d_max); @@ -637,7 +658,11 @@ fn get_deltas_since_supports_incremental_updates_across_serial_wraparound() { let input = format!( "{}delta_window:\n{}", - get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), u32::MAX.wrapping_sub(1)), + get_deltas_since_input_to_string( + cache.session_id_for_version(1), + cache.serial(), + u32::MAX.wrapping_sub(1) + ), indent_block(&deltas_window_to_string(&deltas), 2), ); let output = serial_result_detail_to_string(&result); @@ -678,12 +703,16 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_ )); let d_zero = Arc::new(Delta::new( 0, - vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], + vec![Payload::RouteOrigin(v4_origin( + 198, 51, 100, 0, 24, 24, 64497, + ))], vec![], )); let d_one = Arc::new(Delta::new( 1, - vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], + vec![Payload::RouteOrigin(v4_origin( + 203, 0, 113, 0, 24, 24, 64498, + ))], vec![], )); @@ -709,7 +738,11 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_ let input = format!( "{}delta_window:\n{}", - get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), client_serial), + get_deltas_since_input_to_string( + cache.session_id_for_version(1), + 
cache.serial(), + client_serial + ), indent_block(&deltas_window_to_string(&deltas), 2), ); let output = serial_result_detail_to_string(&result); @@ -737,7 +770,8 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future_acros let result = cache.get_deltas_since(0); - let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 0); + let input = + get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 0); let output = serial_result_detail_to_string(&result); test_report( @@ -776,10 +810,7 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { let dir = tempfile::tempdir().unwrap(); let store = RtrStore::open(dir.path()).unwrap(); - let new_payloads = vec![ - Payload::RouteOrigin(old_b), - Payload::RouteOrigin(old_a), - ]; + let new_payloads = vec![Payload::RouteOrigin(old_b), Payload::RouteOrigin(old_a)]; cache.update(new_payloads.clone(), &store).unwrap(); @@ -795,7 +826,10 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() { let output = format!( "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), - indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block( + &snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), + 2 + ), indent_block(&serial_result_detail_to_string(&result), 2), ); @@ -854,7 +888,10 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() { let output = format!( "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), - indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block( + &snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), + 2 + ), indent_block(&serial_result_detail_to_string(&result), 2), ); @@ -921,7 +958,10 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() { let output = format!( "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), - indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block( + &snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), + 2 + ), indent_block(&serial_result_detail_to_string(&result), 2), ); @@ -997,7 +1037,10 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() { let output = format!( "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", cache.serial(), - indent_block(&snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), 2), + indent_block( + &snapshot_hashes_and_sorted_view_to_string(¤t_snapshot), + 2 + ), indent_block(&serial_result_detail_to_string(&result), 2), ); @@ -1282,10 +1325,7 @@ fn get_deltas_since_merges_multiple_deltas_to_final_minimal_view() { #[test] fn snapshot_from_payloads_unions_aspas_by_customer() { - let first = Payload::Aspa(Aspa::new( - Asn::from(64496u32), - vec![Asn::from(64497u32)], - )); + let first = Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32)])); let second = Payload::Aspa(Aspa::new( Asn::from(64496u32), vec![Asn::from(64498u32), Asn::from(64497u32)], @@ -1369,31 +1409,21 @@ fn validate_payloads_for_rtr_rejects_unsorted_snapshot_payloads() { let high = Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)); let err = validate_payloads_for_rtr(&[low, high], true).unwrap_err(); - assert!(err - .to_string() - .contains("RTR payload ordering violation")); + assert!(err.to_string().contains("RTR payload ordering violation")); } 
#[test] fn validate_payload_updates_for_rtr_rejects_unsorted_aspa_updates() { let withdraw = ( false, - Payload::Aspa(Aspa::new( - Asn::from(64497u32), - vec![Asn::from(64500u32)], - )), + Payload::Aspa(Aspa::new(Asn::from(64497u32), vec![Asn::from(64500u32)])), ); let announce = ( true, - Payload::Aspa(Aspa::new( - Asn::from(64496u32), - vec![Asn::from(64499u32)], - )), + Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64499u32)])), ); let err = validate_payload_updates_for_rtr(&[withdraw, announce]).unwrap_err(); assert!(err.to_string().contains("withdraw ASPA")); assert!(err.to_string().contains("announce ASPA")); } - - diff --git a/tests/test_ccr.rs b/tests/test_ccr.rs index 10463e6..da4275d 100644 --- a/tests/test_ccr.rs +++ b/tests/test_ccr.rs @@ -1,14 +1,10 @@ use std::fs; use std::path::PathBuf; -use rpki::rtr::ccr::{ - ParsedCcrSnapshot, - find_latest_ccr_file, - load_ccr_snapshot_from_file, - snapshot_to_payloads_with_options, -}; use rpki::rtr::loader::{ParsedAspa, ParsedVrp}; use tempfile::tempdir; +use rpki::source::ccr::{find_latest_ccr_file, load_ccr_snapshot_from_file, + snapshot_to_payloads_with_options, ParsedCcrSnapshot}; fn fixture_path(name: &str) -> PathBuf { PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("data").join(name) diff --git a/tests/test_slurm.rs b/tests/test_slurm.rs new file mode 100644 index 0000000..2a47a6c --- /dev/null +++ b/tests/test_slurm.rs @@ -0,0 +1,311 @@ +use base64::Engine; +use base64::engine::general_purpose::STANDARD_NO_PAD; + +use rpki::data_model::resources::as_resources::Asn; +use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; +use rpki::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski}; +use rpki::slurm::file::{SlurmFile, SlurmVersion}; + +fn sample_spki() -> Vec { + vec![ + 0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, + 0x05, 0x00, 0x03, 0x02, 0x00, 0x00, + ] +} + +fn sample_ski() -> [u8; 20] { + [ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, + 0xff, 0x10, 0x20, 0x30, 0x40, + ] +} + +#[test] +fn parses_rfc8416_v1_slurm() { + let ski_hex = hex::encode(sample_ski()); + let router_public_key = STANDARD_NO_PAD.encode(sample_spki()); + let json = format!( + r#"{{ + "slurmVersion": 1, + "validationOutputFilters": {{ + "prefixFilters": [ + {{ "prefix": "192.0.2.0/24", "asn": 64496, "comment": "drop roa" }} + ], + "bgpsecFilters": [ + {{ "asn": 64497, "SKI": "{ski_hex}" }} + ] + }}, + "locallyAddedAssertions": {{ + "prefixAssertions": [ + {{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }} + ], + "bgpsecAssertions": [ + {{ "asn": 64501, "SKI": "{ski_hex}", "routerPublicKey": "{router_public_key}" }} + ] + }} + }}"# + ); + + let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap(); + + assert_eq!(slurm.version(), SlurmVersion::V1); + assert_eq!(slurm.validation_output_filters().prefix_filters.len(), 1); + assert_eq!(slurm.validation_output_filters().bgpsec_filters.len(), 1); + assert!(slurm.validation_output_filters().aspa_filters.is_empty()); + assert_eq!(slurm.locally_added_assertions().prefix_assertions.len(), 1); + assert_eq!(slurm.locally_added_assertions().bgpsec_assertions.len(), 1); + assert!(slurm.locally_added_assertions().aspa_assertions.is_empty()); +} + +#[test] +fn parses_v2_slurm_with_aspa_extensions() { + let json = r#"{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [], + "aspaFilters": [ + { "customerAsn": 64496 } + ] + }, 
+ "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [], + "aspaAssertions": [ + { "customerAsn": 64510, "providerAsns": [64511, 64512] } + ] + } + }"#; + + let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap(); + + assert_eq!(slurm.version(), SlurmVersion::V2); + assert_eq!(slurm.validation_output_filters().aspa_filters.len(), 1); + assert_eq!(slurm.locally_added_assertions().aspa_assertions.len(), 1); +} + +#[test] +fn rejects_v1_file_with_aspa_members() { + let json = r#"{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [], + "aspaFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [] + } + }"#; + + let err = SlurmFile::from_slice(json.as_bytes()).unwrap_err(); + assert!(err.to_string().contains("unknown field")); +} + +#[test] +fn rejects_non_canonical_prefixes_and_unsorted_aspa_providers() { + let non_canonical = r#"{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [ + { "prefix": "192.0.2.1/24" } + ], + "bgpsecFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [] + } + }"#; + let non_canonical_err = SlurmFile::from_slice(non_canonical.as_bytes()).unwrap_err(); + assert!(non_canonical_err.to_string().contains("not canonical")); + + let unsorted_aspa = r#"{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [], + "aspaFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [], + "aspaAssertions": [ + { "customerAsn": 64500, "providerAsns": [64502, 64501] } + ] + } + }"#; + let aspa_err = SlurmFile::from_slice(unsorted_aspa.as_bytes()).unwrap_err(); + assert!(aspa_err.to_string().contains("strictly increasing")); +} + +#[test] +fn applies_filters_before_assertions_and_excludes_duplicates() { + let ski = Ski::from_bytes(sample_ski()); + let spki = sample_spki(); + let spki_b64 = STANDARD_NO_PAD.encode(&spki); + let ski_hex = hex::encode(sample_ski()); + let json = format!( + r#"{{ + "slurmVersion": 2, + "validationOutputFilters": {{ + "prefixFilters": [ + {{ "prefix": "192.0.2.0/24", "asn": 64496 }} + ], + "bgpsecFilters": [ + {{ "SKI": "{ski_hex}" }} + ], + "aspaFilters": [ + {{ "customerAsn": 64496 }} + ] + }}, + "locallyAddedAssertions": {{ + "prefixAssertions": [ + {{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }}, + {{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }} + ], + "bgpsecAssertions": [ + {{ "asn": 64501, "SKI": "{ski_hex}", "routerPublicKey": "{spki_b64}" }} + ], + "aspaAssertions": [ + {{ "customerAsn": 64510, "providerAsns": [64511, 64512] }} + ] + }} + }}"# + ); + let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap(); + + let input = vec![ + Payload::RouteOrigin(RouteOrigin::new( + IPAddressPrefix::new(IPAddress::from_ipv4("192.0.2.0".parse().unwrap()), 24), + 24, + Asn::from(64496u32), + )), + Payload::RouteOrigin(RouteOrigin::new( + IPAddressPrefix::new(IPAddress::from_ipv4("203.0.113.0".parse().unwrap()), 24), + 24, + Asn::from(64497u32), + )), + Payload::RouterKey(RouterKey::new(ski, Asn::from(64497u32), spki.clone())), + Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64498u32)])), + ]; + + let output = slurm.apply(&input); + + assert_eq!(output.len(), 4); + assert!(output.iter().any(|payload| matches!( + payload, + Payload::RouteOrigin(route_origin) + if route_origin.prefix().address() == IPAddress::from_ipv4("203.0.113.0".parse().unwrap()) 
+ ))); + assert!(output.iter().any(|payload| matches!( + payload, + Payload::RouteOrigin(route_origin) + if route_origin.prefix().address() == IPAddress::from_ipv4("198.51.100.0".parse().unwrap()) + ))); + assert!(output.iter().any(|payload| matches!( + payload, + Payload::RouterKey(router_key) + if router_key.asn() == Asn::from(64501u32) + ))); + assert!(output.iter().any(|payload| matches!( + payload, + Payload::Aspa(aspa) + if aspa.customer_asn() == Asn::from(64510u32) + ))); +} + +#[test] +fn merges_multiple_slurm_files_without_conflict() { + let a = r#"{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [ + { "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 } + ], + "bgpsecAssertions": [] + } + }"#; + + let b = r#"{ + "slurmVersion": 2, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [], + "aspaFilters": [ + { "customerAsn": 64510 } + ] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [], + "aspaAssertions": [] + } + }"#; + + let merged = SlurmFile::merge_named(vec![ + ( + "a.slurm".to_string(), + SlurmFile::from_slice(a.as_bytes()).unwrap(), + ), + ( + "b.slurm".to_string(), + SlurmFile::from_slice(b.as_bytes()).unwrap(), + ), + ]) + .unwrap(); + + assert_eq!(merged.version(), SlurmVersion::V2); + assert_eq!(merged.locally_added_assertions().prefix_assertions.len(), 1); + assert_eq!(merged.validation_output_filters().aspa_filters.len(), 1); +} + +#[test] +fn rejects_conflicting_multiple_slurm_files() { + let a = r#"{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [], + "bgpsecFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [ + { "prefix": "10.0.0.0/8", "asn": 64500, "maxPrefixLength": 24 } + ], + "bgpsecAssertions": [] + } + }"#; + + let b = r#"{ + "slurmVersion": 1, + "validationOutputFilters": { + "prefixFilters": [ + { "prefix": "10.0.0.0/16" } + ], + "bgpsecFilters": [] + }, + "locallyAddedAssertions": { + "prefixAssertions": [], + "bgpsecAssertions": [] + } + }"#; + + let err = SlurmFile::merge_named(vec![ + ( + "a.slurm".to_string(), + SlurmFile::from_slice(a.as_bytes()).unwrap(), + ), + ( + "b.slurm".to_string(), + SlurmFile::from_slice(b.as_bytes()).unwrap(), + ), + ]) + .unwrap_err(); + + assert!(err.to_string().contains("conflicting SLURM files")); +}