新增slurm功能

This commit is contained in:
xiuting.xu 2026-04-01 16:24:01 +08:00
parent 8d6173f605
commit cd0330e8ae
49 changed files with 2738 additions and 1052 deletions

10
.dockerignore Normal file
View File

@ -0,0 +1,10 @@
.git
.gitignore
.idea
target
tmp_slurm_output.json
rtr-db
tests
specs
scripts
README.md

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
target/ target/
Cargo.lock Cargo.lock
rtr-db/ rtr-db/
.idea/

View File

@ -21,6 +21,7 @@ rand = "0.10.0"
rocksdb = { version = "0.21.0", default-features = false } rocksdb = { version = "0.21.0", default-features = false }
serde = { version = "1", features = ["derive", "rc"] } serde = { version = "1", features = ["derive", "rc"] }
serde_json = "1" serde_json = "1"
base64 = "0.22"
anyhow = "1" anyhow = "1"
tracing = "0.1.44" tracing = "0.1.44"
sha2 = "0.10" sha2 = "0.10"

Binary file not shown.

Binary file not shown.

23
data/example.slurm Normal file
View File

@ -0,0 +1,23 @@
{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [
{
"prefix": "24.0.0.0/8",
"comment": "Filter many VRPs in current CCR sample"
}
],
"bgpsecFilters": [],
"aspaFilters": [
{
"customerAsn": 80,
"comment": "Filter one ASPA known to exist in current CCR sample"
}
]
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": [],
"aspaAssertions": []
}
}

40
deploy/DEPLOYMENT.md Normal file
View File

@ -0,0 +1,40 @@
# Deployment (Supervisor + Docker Compose)
This project runs `src/main.rs` as a long-running server that:
1. loads latest `.ccr` from a configured directory,
2. applies optional SLURM filtering,
3. starts RTR server.
`supervisord` is used as PID 1 in container to keep the process managed and auto-restarted.
## Files
- `deploy/Dockerfile`
- `deploy/supervisord.conf`
- `deploy/docker-compose.yml`
## Runtime Paths in Container
- CCR directory: `/app/data`
- RocksDB directory: `/app/rtr-db`
- SLURM directory: `/app/slurm`
- TLS cert directory (optional): `/app/certs`
## Start
```bash
docker compose -f deploy/docker-compose.yml up -d --build
```
## Stop
```bash
docker compose -f deploy/docker-compose.yml down
```
## Logs
```bash
docker compose -f deploy/docker-compose.yml logs -f rpki-rtr
```

34
deploy/Dockerfile Normal file
View File

@ -0,0 +1,34 @@
FROM rust:1.86-bookworm AS builder
WORKDIR /build
COPY Cargo.toml Cargo.lock ./
COPY src ./src
RUN cargo build --release --bin rpki
FROM debian:bookworm-slim AS runtime
RUN apt-get update \
&& apt-get install -y --no-install-recommends ca-certificates supervisor \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=builder /build/target/release/rpki /usr/local/bin/rpki
COPY deploy/supervisord.conf /etc/supervisor/conf.d/rpki-rtr.conf
RUN mkdir -p /app/data /app/rtr-db /app/certs /app/slurm /var/log/supervisor
ENV RPKI_RTR_ENABLE_TLS=false \
RPKI_RTR_TCP_ADDR=0.0.0.0:323 \
RPKI_RTR_TLS_ADDR=0.0.0.0:324 \
RPKI_RTR_DB_PATH=/app/rtr-db \
RPKI_RTR_CCR_DIR=/app/data \
RPKI_RTR_SLURM_DIR=/app/slurm \
RPKI_RTR_REFRESH_INTERVAL_SECS=300 \
RPKI_RTR_STRICT_CCR_VALIDATION=false
EXPOSE 323 324
CMD ["supervisord", "-n", "-c", "/etc/supervisor/conf.d/rpki-rtr.conf"]

28
deploy/docker-compose.yml Normal file
View File

@ -0,0 +1,28 @@
version: "3.9"
services:
rpki-rtr:
build:
context: ..
dockerfile: deploy/Dockerfile
image: rpki-rtr:latest
container_name: rpki-rtr
restart: unless-stopped
ports:
- "323:323"
- "324:324"
environment:
RPKI_RTR_ENABLE_TLS: "false"
RPKI_RTR_TCP_ADDR: "0.0.0.0:323"
RPKI_RTR_TLS_ADDR: "0.0.0.0:324"
RPKI_RTR_DB_PATH: "/app/rtr-db"
RPKI_RTR_CCR_DIR: "/app/data"
RPKI_RTR_SLURM_DIR: "/app/slurm"
RPKI_RTR_STRICT_CCR_VALIDATION: "false"
RPKI_RTR_REFRESH_INTERVAL_SECS: "300"
volumes:
- ../data:/app/data:ro
- ../rtr-db:/app/rtr-db
- ../data:/app/slurm:ro
# TLS mode example:
# - ../certs:/app/certs:ro

18
deploy/supervisord.conf Normal file
View File

@ -0,0 +1,18 @@
[supervisord]
nodaemon=true
logfile=/dev/null
pidfile=/tmp/supervisord.pid
[program:rpki-rtr]
command=/usr/local/bin/rpki
autostart=true
autorestart=true
startsecs=2
startretries=3
stopsignal=TERM
stopasgroup=true
killasgroup=true
stdout_logfile=/dev/fd/1
stdout_logfile_maxbytes=0
stderr_logfile=/dev/fd/2
stderr_logfile_maxbytes=0

View File

@ -1,34 +1,23 @@
# 10. SLURM(Simplified Local Internet Number Resource Management with the RPKI) # 10. SLURMSimplified Local Internet Number Resource Management with the RPKI
## 10.1 对象定位 ## 10.1 目标与范围
SLURM是一个JSON文件允许 RPKI 依赖方在本地“覆盖/修正/忽略”来自上游RPKI数据的内容而不需要修改或伪造原始RPKI对象。 SLURM 用于让 RPRelying Party在本地对上游 RPKI 验证结果做“过滤”和“补充断言”,而不修改上游发布对象。
## 10.2 数据格式 RFC 8416 §3) 本文档基于:
- RFC 8416SLURM v1ROA/BGPsec
- draft-ietf-sidrops-aspa-slurm-04SLURM v2新增 ASPA
### SLURM ## 10.2 版本与顶层结构
SLURM是一个只包含一个JSON对象的文件。格式要求如下RFC 8416 §3.2 ### 10.2.1 SLURM v1RFC 8416
```text `slurmVersion` 必须为 `1`,且顶层 JSON 对象必须包含且仅包含以下成员:
A SLURM file consists of a single JSON object containing the - `slurmVersion`
following members: - `validationOutputFilters`(必须包含 `prefixFilters``bgpsecFilters`
o A "slurmVersion" member that MUST be set to 1, encoded as a number - `locallyAddedAssertions`(必须包含 `prefixAssertions``bgpsecAssertions`
o A "validationOutputFilters" member (Section 3.3), whose value is
an object. The object MUST contain exactly two members:
* A "prefixFilters" member, whose value is described in
Section 3.3.1.
* A "bgpsecFilters" member, whose value is described in
Section 3.3.2.
o A "locallyAddedAssertions" member (Section 3.4), whose value is an
object. The object MUST contain exactly two members:
* A "prefixAssertions" member, whose value is described in
Section 3.4.1.
* A "bgpsecAssertions" member, whose value is described in
Section 3.4.2.
```
一个空的SLURM json结构体如下 空策略示例:
```json ```json
{ {
@ -44,193 +33,154 @@ following members:
} }
``` ```
### prefixFilters ### 10.2.2 SLURM v2draft-04
其中`prefixFilters`格式要求如下RFC 8416 §3.3.1
`slurmVersion` 必须为 `2`,在 v1 基础上扩展 ASPA 两类成员:
- `validationOutputFilters.aspaFilters`
- `locallyAddedAssertions.aspaAssertions`
空策略示例:
```text
The above is expressed as a value of the "prefixFilters" member, as
an array of zero or more objects. Each object MUST contain either 1)
one of the following members or 2) one of each of the following
members.
o A "prefix" member, whose value is a string representing either an
IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see
[RFC5952]).
o An "asn" member, whose value is a number.
In addition, each object MAY contain one optional "comment" member,
whose value is a string.
```
示例:
```json ```json
{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": [],
"aspaFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": [],
"aspaAssertions": []
}
}
```
## 10.3 字段规范RFC 8416
### 10.3.1 `prefixFilters`
数组元素每项:
- 必须至少包含一个:`prefix``asn`
- 可选:`comment`
匹配规则:
- 若配置了 `prefix`匹配“被该前缀覆盖encompassed”的 VRP 前缀
- 若配置了 `asn`:匹配该 ASN
- 同时配置时:两者都要匹配
### 10.3.2 `bgpsecFilters`
数组元素每项:
- 必须至少包含一个:`asn``SKI`
- 可选:`comment`
匹配规则:
- 按 `asn`/`SKI` 单独或联合匹配 Router KeyBGPsec
### 10.3.3 `prefixAssertions`
数组元素每项:
- 必须:`prefix``asn`
- 可选:`maxPrefixLength``comment`
约束:
- 若给出 `maxPrefixLength`,应满足 `prefix 长度 <= maxPrefixLength <= 地址位宽(IPv4=32, IPv6=128)`
### 10.3.4 `bgpsecAssertions`
数组元素每项:
- 必须:`asn``SKI``routerPublicKey`
- 可选:`comment`
## 10.4 ASPA 扩展draft-ietf-sidrops-aspa-slurm-04
### 10.4.1 `aspaFilters`
数组元素每项:
- 必须:`customerAsn`
- 可选:`comment`
匹配规则:
- 当 VAPValidated ASPA Payload`customerAsn` 等于过滤器 `customerAsn` 时命中并移除。
### 10.4.2 `aspaAssertions`
数组元素每项:
- 必须:`customerAsn`
- 必须:`providerAsns`ASN 数组)
- 可选:`comment`
关键约束draft-04
- `customerAsn` 不得出现在 `providerAsns`
- `providerAsns` 必须按升序排列
- `providerAsns` 里的 ASN 必须唯一(无重复)
语义补充draft-04
- `aspaAssertions` 仅用于“新增断言”,不构成隐式过滤(不会自动替代 `aspaFilters`)。
- 在 RTRv2 输出阶段,新增的 ASPA 断言应加入 ASPA PDU 集合,并做去重。
## 10.5 应用语义RFC 8416 Section 4
### 10.5.1 原子性
SLURM 应用必须是原子的:
- 要么完全不生效(等同未使用 SLURM
- 要么完整按当前 SLURM 配置生效
### 10.5.2 处理顺序
在同一次计算中:
1. 先执行 `validationOutputFilters`(移除匹配验证结果)
2. 再追加 `locallyAddedAssertions`
### 10.5.3 多文件
实现可以支持多个 SLURM 文件并行使用(取并集),但在启用前应检查断言重叠冲突;若存在冲突,整组文件应被拒绝。
## 10.6 最小可用示例SLURM v2
```json
{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [ "prefixFilters": [
{ {
"prefix": "192.0.2.0/24", "prefix": "203.0.113.0/24",
"comment": "All VRPs encompassed by prefix" "comment": "Filter a broken VRP from upstream"
}, }
],
"bgpsecFilters": [],
"aspaFilters": [
{ {
"asn": 64496, "customerAsn": 64496,
"comment": "All VRPs matching ASN" "comment": "Filter one customer ASPA"
},
{
"prefix": "198.51.100.0/24",
"asn": 64497,
"comment": "All VRPs encompassed by prefix, matching ASN"
} }
] ]
```
### bgpsecFilters
`bgpsecFilters`格式要求如下RFC 8416 §3.3.2
```text
The above is expressed as a value of the "bgpsecFilters" member, as
an array of zero or more objects. Each object MUST contain one of
either, or one each of both following members:
o An "asn" member, whose value is a number
o An "SKI" member, whose value is the Base64 encoding without
trailing = (Section 5 of [RFC4648]) of the certificates Subject
Key Identifier as described in Section 4.8.2 of [RFC6487]. (This
is the value of the ASN.1 OCTET STRING without the ASN.1 tag or
length fields.)
In addition, each object MAY contain one optional "comment" member,
whose value is a string.
```
示例:
```json
"bgpsecFilters": [
{
"asn": 64496,
"comment": "All keys for ASN"
}, },
{ "locallyAddedAssertions": {
"SKI": "<Base 64 of some SKI>",
"comment": "Key matching Router SKI"
},
{
"asn": 64497,
"SKI": "<Base 64 of some SKI>",
"comment": "Key for ASN 64497 matching Router SKI"
}
]
```
### prefixAssertions
`prefixAssertions`格式要求如下RFC 8416 §3.4.1
```text
The above is expressed as a value of the "prefixAssertions" member,
as an array of zero or more objects. Each object MUST contain one of
each of the following members:
o A "prefix" member, whose value is a string representing either an
IPv4 prefix (see Section 3.1 of [RFC4632]) or an IPv6 prefix (see
[RFC5952]).
o An "asn" member, whose value is a number.
In addition, each object MAY contain one of each of the following
members:
o A "maxPrefixLength" member, whose value is a number.
o A "comment" member, whose value is a string.
```
示例:
```json
"prefixAssertions": [ "prefixAssertions": [
{ {
"asn": 64496, "asn": 64496,
"prefix": "198.51.100.0/24", "prefix": "203.0.113.0/24",
"comment": "My other important route" "maxPrefixLength": 24,
}, "comment": "Local business exception"
}
],
"bgpsecAssertions": [],
"aspaAssertions": [
{ {
"asn": 64496, "customerAsn": 64496,
"prefix": "2001:DB8::/32", "providerAsns": [64497, 64498],
"maxPrefixLength": 48, "comment": "Local ASPA assertion"
"comment": "My other important de-aggregated routes"
} }
] ]
```
### bgpsecAssertions
`bgpsecAssertions`格式要求如下RFC 8416 §3.4.2
```text
The above is expressed as a value of the "bgpsecAssertions" member,
as an array of zero or more objects. Each object MUST contain one
each of all of the following members:
o An "asn" member, whose value is a number.
o An "SKI" member, whose value is the Base64 encoding without
trailing = (Section 5 of [RFC4648]) of the certificates Subject
Key Identifier as described in Section 4.8.2 of [RFC6487] (This is
the value of the ASN.1 OCTET STRING without the ASN.1 tag or
length fields.)
o A "routerPublicKey" member, whose value is the Base64 encoding
without trailing = (Section 5 of [RFC4648]) of the equivalent to
the subjectPublicKeyInfo value of the router certificates public
key, as described in [RFC8208]. This is the full ASN.1 DER
encoding of the subjectPublicKeyInfo, including the ASN.1 tag and
length values of the subjectPublicKeyInfo SEQUENCE.
```
示例:
```json
"bgpsecAssertions": [
{
"asn": 64496,
"SKI": "<some base64 SKI>",
"routerPublicKey": "<some base64 public key>",
"comment": "My known key for my important ASN"
} }
] }
``` ```
## 10.3 抽象数据结构 ## 10.7 参考文献
### SLURM
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|---------------------------|------------------------|---------|---------|---------------|
| slurm_version | number | SLURM版本 | 版本必须为1 | RFC 8416 §3.2 |
| validation_output_filters | ValidationOutputFilter | 过滤条件 | | |
| locally_added_assertions | LocallyAddedAssertions | 本地添加断言 | | |
### ValidationOutputFilter
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|----------------|-------------------|-----------|---------|---------------|
| prefix_filters | Vec<PrefixFilter> | 前缀过滤 | 可以为空数组 | RFC 8416 §3.3 |
| bgpsec_filters | Vec<BgpsecFilter> | BGPsec过滤 | 可以为空数组 | RFC 8416 §3.3 |
### LocallyAddedAssertions
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|-------------------|----------------------|-----------|---------|---------------|
| prefix_assertions | Vec<PrefixAssertion> | 前缀断言 | 可以为空数组 | RFC 8416 §3.4 |
| bgpsec_assertions | Vec<BgpsecAssertion> | BGPsec断言 | 可以为空数组 | RFC 8416 §3.4 |
### PrefixFilter
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|---------|--------|------|--------------------------------|-----------------|
| prefix | string | 前缀 | IPv4前缀或IPv6前缀prefix和asn至少存在一个 | RFC 8416 §3.3.1 |
| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 |
| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 |
### BgpsecFilter
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|---------|--------|------|------------------|------------------|
| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.3.1 |
| ski | u8 | | 证书的SKI | RFC 8416 §3.3.1 |
| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.3.1 |
### PrefixAssertion
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|-------------------|--------|--------|---------------|-----------------|
| prefix | string | 前缀 | IPv4前缀或IPv6前缀 | RFC 8416 §3.4.1 |
| asn | number | ASN | | RFC 8416 §3.4.1 |
| max_prefix_length | number | 最大前缀长度 | 可选字段 | RFC 8416 §3.4.1 |
| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.1 |
### BgpsecAssertion
| 字段 | 类型 | 语义 | 约束/解析规则 | RFC 引用 |
|-------------------|--------|--------|------------------|-----------------|
| asn | number | ASN | prefix和asn至少存在一个 | RFC 8416 §3.4.2 |
| ski | u8 | | 证书的SKI | RFC 8416 §3.4.2 |
| router_public_key | u8 | 证书的SKI | | RFC 8416 §3.4.2 |
| comment | string | 备注说明 | 可选字段 | RFC 8416 §3.4.2 |
> 注BGPsec部分可以在第一版考虑先留空
## 10.4 规则
- RFC 8416: https://www.rfc-editor.org/rfc/rfc8416.html
- draft-ietf-sidrops-aspa-slurm-04: https://www.ietf.org/archive/id/draft-ietf-sidrops-aspa-slurm-04.html

View File

@ -1,65 +1,158 @@
# 11. RTR (The Resource Public Key Infrastructure (RPKI) to Router Protocol) # 11. RTRRPKI to Router Protocol
## 11.1 Cache Server ## 11.1 目标与文档范围
### 11.1.1 功能需求 RTR 用于把 RP/Cache 已完成密码学验证的 RPKI 数据下发给路由器。
- 支持Full SyncReset Query 本文按以下规范整理:
- 支持Incremental SyncSerial Query - RFC 6810RTR v0
- 支持多客户端并发 - RFC 8210RTR v1更新 RFC 6810
- 支持Serial递增 - draft-ietf-sidrops-8210bis-25RTR v2草案
- 保留一定数量的delta
- 支持原子更新
### 11.1.2 架构设计 ## 11.2 协议演进
采用一级缓存+二级缓存并存的方式。
![img.png](img/img.png) ### 11.2.1 RFC 6810v0
其中,一级缓存为运行时缓存,主要职责: - 只定义 Prefix Origin 相关 payloadIPv4/IPv6 Prefix PDU
- 存储当前完整的snapshot - 主要 PDUSerial Notify / Serial Query / Reset Query / Cache Response / Prefix / End of Data / Cache Reset / Error Report。
- 历史Delta队列管理
- Serial管理
- RTR查询响应
二级缓存为持久化缓存,主要职责: ### 11.2.2 RFC 8210v1
- snapshot持久化
- 缓存重启后的快速恢复snapshot和serial
- 不参与实时查询
- 异步写入
### 11.1.3 核心数据结构设计 在 v0 基础上新增/强化:
- 新增 `Router Key PDU`PDU Type 9v1 可用v0 保留)。
- 强化协议版本协商与降级行为。
- `End of Data` 在 v1 中携带 `Refresh/Retry/Expire` 三个计时参数。
#### 11.1.3.1 总cache ### 11.2.3 Version 2草案
```rust
struct RtrCache {
serial: AtomicU32,
snapshot: ArcSwap<Snapshot>,
deltas: RwLock<VecDeque<Arc<Delta>>>,
max_delta: usize,
}
```
#### 11.1.3.2 Snapshot 在 v1 基础上新增/强化:
```rust - 新增 `ASPA PDU`PDU Type 11仅 v2
struct Snapshot { - 新增 “Races, Ordering, and Transactions” 章节,要求缓存按规定顺序输出 payload 以降低路由器短暂误判。
origins: Vec<RouteOrigin>, - 协议版本提升到 `2`
router_keys: Vec<RouterKey>, - 明确 PDU 最大长度上限为 64k65535
aspas: Vec<Aspa>,
created_at: Instant,
}
```
#### 11.1.3.3 Delta ## 11.3 PDU 与版本矩阵
```rust
struct Delta {
serial: u32,
announced: Vec<Payload>,
withdrawn: Vec<Payload>,
}
```
PDU 类型(按规范注册表):
## 11.2 Transport | PDU Type | 名称 | v0 (RFC6810) | v1 (RFC8210) | v2 (8210bis-25) |
|---|---|---|---|---|
| 0 | Serial Notify | 支持 | 支持 | 支持 |
| 1 | Serial Query | 支持 | 支持 | 支持 |
| 2 | Reset Query | 支持 | 支持 | 支持 |
| 3 | Cache Response | 支持 | 支持 | 支持 |
| 4 | IPv4 Prefix | 支持 | 支持 | 支持 |
| 6 | IPv6 Prefix | 支持 | 支持 | 支持 |
| 7 | End of Data | 支持 | 支持(含计时参数) | 支持 |
| 8 | Cache Reset | 支持 | 支持 | 支持 |
| 9 | Router Key | 保留 | 支持 | 支持 |
| 10 | Error Report | 支持 | 支持 | 支持 |
| 11 | ASPA | 保留 | 保留 | 支持 |
初版实现RTR over TLS(可外网)和RTR over TCP内网两种方式。 通用字段约束:
- `Protocol Version`8-bit。
- `PDU Type`8-bit。
- `Session ID`16-bit。
- `Length`32-bit。
- 保留位zero/reserved发送必须为 0接收时按规范处理。
## 11.4 关键 PDU 语义
### 11.4.1 Serial NotifyType 0
- 由 Cache 主动发送,提示有新序列可拉取。
- 是少数可不由 Router 请求触发的消息。
### 11.4.2 Reset QueryType 2与 Cache ResponseType 3
- Router 启动或失配时发 `Reset Query` 请求全量。
- Cache 回复 `Cache Response`,随后发送全量 payload最后 `End of Data`
### 11.4.3 Serial QueryType 1
- Router 持有上次 `Session ID + Serial` 时请求增量。
- Cache 若可提供增量:返回变化集。
- Cache 若无法从该 serial 补增量:返回 `Cache Reset`,要求 Router 走全量。
### 11.4.4 Prefix / Router Key / ASPA payload
- `IPv4 Prefix`Type 4/ `IPv6 Prefix`Type 6表示 VRP 的 announce/withdraw。
- `Router Key`Type 9v1+):表示 BGPsec Router Key 的 announce/withdraw。
- `ASPA`Type 11v2 草案):表示 ASPA 数据单元的 announce/withdraw。
语义要点v1 / v2 草案):
- 对同一 payload 键(如 Prefix 四元组、Router Key 三元组、ASPA customer 键)应维护清晰的替换/撤销关系。
- Cache 负责把历史变化“合并简化”后再发给 Router避免无意义抖动。
### 11.4.5 End of DataType 7
- 标识一次响应结束,并给出当前 serial。
- v0不含定时器字段。
- v1/v2携带 `Refresh Interval``Retry Interval``Expire Interval`
## 11.5 协议时序
### 11.5.1 初始同步Full Sync
1. Router 建连后发 `Reset Query`(带支持的协议版本)。
2. Cache 回 `Cache Response`
3. Cache 按规范发送 payload 集合。
4. Cache 发 `End of Data` 收尾。
### 11.5.2 增量同步Incremental Sync
1. Router 发 `Serial Query(session_id, serial)`
2. Cache 若可增量,返回变化并以 `End of Data` 收尾。
3. 若不可增量,返回 `Cache Reset`Router 退回 Full Sync。
## 11.6 版本协商与降级
- Router 每次新连接必须由 `Reset Query``Serial Query` 启动,携带其协议版本。
- 双方在协商完成后,本连接内版本固定。
- 遇到不支持版本时,可按规范降级(例如 v1 对 v0、v2 对 v1/v0或返回 `Unsupported Protocol Version` 后断开。
- 协商期若收到 `Serial Notify`Router 应按规范兼容处理(通常忽略,待协商完成)。
## 11.7 计时器与失效v1/v2
`End of Data` 下发三个参数:
- `Refresh Interval`:多久后主动刷新。
- `Retry Interval`:失败后重试间隔。
- `Expire Interval`:本地数据最长可保留时长。
规范边界RFC 8210
- Refresh: 1 .. 86400推荐 3600
- Retry: 1 .. 7200推荐 600
- Expire: 600 .. 172800推荐 7200
- 且 `Expire` 必须大于 `Refresh``Retry`
## 11.8 Version 2草案新增关注点
### 11.8.1 ASPA PDU
- 新增 ASPA 传输能力Type 11
- 针对同一 customer ASNCache 需向 Router 提供一致且可替换的 ASPA 视图。
### 11.8.2 排序与事务
- 草案新增 race 条件说明(如前缀替换、撤销先后导致短暂误判)。
- 对 Cache 输出 payload 的顺序提出约束。
- 建议 Router 使用“事务式应用”(例如接收到完整响应后再切换生效)降低中间态影响。
## 11.9 传输与安全
规范定义可承载于多种传输:
- SSH
- TLS
- TCP MD5
- TCP-AO
安全原则:
- Router 与 Cache 之间必须建立可信关系。
- 需要完整性/机密性时优先使用具备认证与加密能力的传输。
- 若使用普通 TCP部署上应限制在可信受控网络中。
## 11.10 参考文献
- RFC 6810: https://www.rfc-editor.org/rfc/rfc6810.html
- RFC 8210: https://www.rfc-editor.org/rfc/rfc8210.html
- draft-ietf-sidrops-8210bis-25: https://www.ietf.org/archive/id/draft-ietf-sidrops-8210bis-25.html

View File

@ -7,18 +7,16 @@ use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf}; use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::{timeout, Duration, Instant}; use tokio::time::{Duration, Instant, timeout};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
mod wire;
mod pretty; mod pretty;
mod protocol; mod protocol;
mod wire;
use crate::wire::{read_pdu, send_reset_query, send_serial_query}; use crate::pretty::{parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu};
use crate::pretty::{
parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu,
};
use crate::protocol::{PduHeader, PduType, QueryMode}; use crate::protocol::{PduHeader, PduType, QueryMode};
use crate::wire::{read_pdu, send_reset_query, send_serial_query};
const DEFAULT_READ_TIMEOUT_SECS: u64 = 30; const DEFAULT_READ_TIMEOUT_SECS: u64 = 30;
const DEFAULT_POLL_INTERVAL_SECS: u64 = 600; const DEFAULT_POLL_INTERVAL_SECS: u64 = 600;
@ -38,7 +36,10 @@ async fn main() -> io::Result<()> {
println!("transport: {}", config.transport.describe()); println!("transport: {}", config.transport.describe());
println!("version : {}", config.version); println!("version : {}", config.version);
println!("timeout : {}s", config.read_timeout_secs); println!("timeout : {}s", config.read_timeout_secs);
println!("poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs); println!(
"poll : {}s (default before EndOfData refresh is known)",
config.default_poll_secs
);
println!("keep-after-error: {}", config.keep_after_error); println!("keep-after-error: {}", config.keep_after_error);
match &config.mode { match &config.mode {
QueryMode::Reset => { QueryMode::Reset => {
@ -72,11 +73,7 @@ async fn main() -> io::Result<()> {
} }
Err(err) => { Err(err) => {
let delay = state.reconnect_delay_secs(); let delay = state.reconnect_delay_secs();
eprintln!( eprintln!("connect failed: {}. retry after {}s", err, delay);
"connect failed: {}. retry after {}s",
err,
delay
);
tokio::time::sleep(Duration::from_secs(delay)).await; tokio::time::sleep(Duration::from_secs(delay)).await;
} }
} }
@ -171,10 +168,7 @@ async fn main() -> io::Result<()> {
if reconnect { if reconnect {
let delay = state.reconnect_delay_secs(); let delay = state.reconnect_delay_secs();
state.current_session_id = None; state.current_session_id = None;
println!( println!("[reconnect] transport disconnected, retry after {}s", delay);
"[reconnect] transport disconnected, retry after {}s",
delay
);
tokio::time::sleep(Duration::from_secs(delay)).await; tokio::time::sleep(Duration::from_secs(delay)).await;
} }
} }
@ -189,8 +183,7 @@ async fn send_resume_query(
(Some(session_id), Some(serial)) => { (Some(session_id), Some(serial)) => {
println!( println!(
"reconnected, send Serial Query with session_id={}, serial={}", "reconnected, send Serial Query with session_id={}, serial={}",
session_id, session_id, serial
serial
); );
send_serial_query(writer, state.version, session_id, serial).await?; send_serial_query(writer, state.version, session_id, serial).await?;
} }
@ -294,8 +287,7 @@ async fn handle_incoming_pdu(
println!(); println!();
println!( println!(
"[notify] received Serial Notify: session_id={}, notify_serial={:?}", "[notify] received Serial Notify: session_id={}, notify_serial={:?}",
notify_session_id, notify_session_id, notify_serial
notify_serial
); );
match (state.session_id, state.serial, notify_serial) { match (state.session_id, state.serial, notify_serial) {
@ -306,12 +298,7 @@ async fn handle_incoming_pdu(
"received Serial Notify for current session {}, send Serial Query with serial {}", "received Serial Notify for current session {}, send Serial Query with serial {}",
current_session_id, current_serial current_session_id, current_serial
); );
send_serial_query( send_serial_query(writer, state.version, current_session_id, current_serial)
writer,
state.version,
current_session_id,
current_serial,
)
.await?; .await?;
} }
@ -366,10 +353,7 @@ async fn handle_incoming_pdu(
Ok(()) Ok(())
} }
async fn handle_poll_tick( async fn handle_poll_tick(writer: &mut ClientWriter, state: &mut ClientState) -> io::Result<()> {
writer: &mut ClientWriter,
state: &mut ClientState,
) -> io::Result<()> {
println!(); println!();
println!( println!(
"[auto-poll] timer fired (interval={}s)", "[auto-poll] timer fired (interval={}s)",
@ -422,8 +406,7 @@ async fn handle_console_command(
state.schedule_next_poll(); state.schedule_next_poll();
} }
["serial"] => { ["serial"] => match (state.session_id, state.serial) {
match (state.session_id, state.serial) {
(Some(session_id), Some(serial)) => { (Some(session_id), Some(serial)) => {
println!( println!(
"manual command: send Serial Query with current state: session_id={}, serial={}", "manual command: send Serial Query with current state: session_id={}, serial={}",
@ -437,8 +420,7 @@ async fn handle_console_command(
"manual command failed: current session_id/serial not available, use `reset` or `serial <session_id> <serial>`" "manual command failed: current session_id/serial not available, use `reset` or `serial <session_id> <serial>`"
); );
} }
} },
}
["serial", session_id, serial] => { ["serial", session_id, serial] => {
let session_id = match session_id.parse::<u16>() { let session_id = match session_id.parse::<u16>() {
@ -493,7 +475,10 @@ async fn handle_console_command(
"current effective poll interval: {}s", "current effective poll interval: {}s",
state.effective_poll_secs() state.effective_poll_secs()
); );
println!("poll interval source : {}", state.poll_interval_source()); println!(
"poll interval source : {}",
state.poll_interval_source()
);
println!("stored refresh hint : {:?}", state.refresh); println!("stored refresh hint : {:?}", state.refresh);
println!("default poll interval : {}s", state.default_poll_secs); println!("default poll interval : {}s", state.default_poll_secs);
println!("last_error_code : {:?}", state.last_error_code); println!("last_error_code : {:?}", state.last_error_code);
@ -626,17 +611,20 @@ impl ClientState {
fn effective_poll_secs(&self) -> u64 { fn effective_poll_secs(&self) -> u64 {
if self.should_prefer_retry_poll() { if self.should_prefer_retry_poll() {
self.retry self.retry.map(|v| v as u64).unwrap_or_else(|| {
self.refresh
.map(|v| v as u64) .map(|v| v as u64)
.unwrap_or_else(|| self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs)) .unwrap_or(self.default_poll_secs)
})
} else { } else {
self.refresh.map(|v| v as u64).unwrap_or(self.default_poll_secs) self.refresh
.map(|v| v as u64)
.unwrap_or(self.default_poll_secs)
} }
} }
fn schedule_next_poll(&mut self) { fn schedule_next_poll(&mut self) {
self.next_poll_deadline = self.next_poll_deadline = Instant::now() + Duration::from_secs(self.effective_poll_secs());
Instant::now() + Duration::from_secs(self.effective_poll_secs());
} }
fn pause_auto_poll(&mut self) { fn pause_auto_poll(&mut self) {
@ -728,7 +716,10 @@ impl Config {
} }
"--server-name" => { "--server-name" => {
let name = args.next().ok_or_else(|| { let name = args.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "--server-name requires a value") io::Error::new(
io::ErrorKind::InvalidInput,
"--server-name requires a value",
)
})?; })?;
ensure_tls_config(&mut transport)?.server_name = Some(name); ensure_tls_config(&mut transport)?.server_name = Some(name);
} }
@ -805,10 +796,7 @@ impl Config {
let serial = positional let serial = positional
.next() .next()
.ok_or_else(|| { .ok_or_else(|| {
io::Error::new( io::Error::new(io::ErrorKind::InvalidInput, "serial mode requires serial")
io::ErrorKind::InvalidInput,
"serial mode requires serial",
)
})? })?
.parse::<u32>() .parse::<u32>()
.map_err(|e| { .map_err(|e| {
@ -949,8 +937,14 @@ async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result<DynStream
format!("invalid TLS server name '{}': {}", server_name_str, err), format!("invalid TLS server name '{}': {}", server_name_str, err),
) )
})?; })?;
let tls_stream = connector.connect(server_name, stream).await.map_err(|err| { let tls_stream = connector
io::Error::new(io::ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", err)) .connect(server_name, stream)
.await
.map_err(|err| {
io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("TLS handshake failed: {}", err),
)
})?; })?;
Ok(Box::new(tls_stream)) Ok(Box::new(tls_stream))
} }
@ -966,7 +960,10 @@ fn build_tls_connector(tls: &TlsConfig) -> io::Result<TlsConnector> {
if added == 0 { if added == 0 {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
format!("no valid CA certificates found in {}", ca_cert_path.display()), format!(
"no valid CA certificates found in {}",
ca_cert_path.display()
),
)); ));
} }

View File

@ -1,9 +1,8 @@
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use crate::protocol::{ use crate::protocol::{
flag_meaning, hex_bytes, PduHeader, PduType, ASPA_FIXED_BODY_LEN, ASPA_FIXED_BODY_LEN, END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN,
END_OF_DATA_V0_BODY_LEN, END_OF_DATA_V1_BODY_LEN, IPV4_PREFIX_BODY_LEN, IPV6_PREFIX_BODY_LEN, PduHeader, PduType, ROUTER_KEY_FIXED_BODY_LEN, flag_meaning, hex_bytes,
IPV6_PREFIX_BODY_LEN, ROUTER_KEY_FIXED_BODY_LEN,
}; };
pub fn print_pdu(header: &PduHeader, body: &[u8]) { pub fn print_pdu(header: &PduHeader, body: &[u8]) {
@ -143,8 +142,7 @@ fn print_error_report(header: &PduHeader, body: &[u8]) {
return; return;
} }
let encapsulated_len = let encapsulated_len = u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize;
u32::from_be_bytes([body[0], body[1], body[2], body[3]]) as usize;
if body.len() < 4 + encapsulated_len + 4 { if body.len() < 4 + encapsulated_len + 4 {
println!("invalid ErrorReport: truncated encapsulated PDU"); println!("invalid ErrorReport: truncated encapsulated PDU");

View File

@ -2,9 +2,7 @@ use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::protocol::{ use crate::protocol::{HEADER_LEN, MAX_PDU_LEN, PduHeader, PduType, RawPdu, SERIAL_QUERY_LEN};
PduHeader, PduType, RawPdu, HEADER_LEN, MAX_PDU_LEN, SERIAL_QUERY_LEN,
};
pub async fn send_reset_query<S>(stream: &mut S, version: u8) -> io::Result<()> pub async fn send_reset_query<S>(stream: &mut S, version: u8) -> io::Result<()>
where where
@ -56,10 +54,7 @@ where
if header.length < HEADER_LEN as u32 { if header.length < HEADER_LEN as u32 {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!( format!("invalid PDU length {} < {}", header.length, HEADER_LEN),
"invalid PDU length {} < {}",
header.length, HEADER_LEN
),
)); ));
} }

View File

@ -0,0 +1,168 @@
use std::env;
use std::fs::File;
use std::path::PathBuf;
use anyhow::{Context, Result, anyhow};
use serde::Serialize;
use rpki::rtr::payload::Payload;
use rpki::slurm::file::SlurmFile;
use rpki::source::ccr::{
find_latest_ccr_file, load_ccr_snapshot_from_file, snapshot_to_payloads_with_options,
};
#[derive(Debug)]
struct Cli {
ccr_path: PathBuf,
slurm_path: PathBuf,
strict_ccr: bool,
dump_payloads: bool,
}
#[derive(Debug, Serialize)]
struct Output {
ccr_path: String,
slurm_path: String,
produced_at: Option<String>,
slurm_version: u32,
input_payload_count: usize,
input_vrp_count: usize,
input_vap_count: usize,
output_payload_count: usize,
output_vrp_count: usize,
output_vap_count: usize,
invalid_vrps: Vec<String>,
invalid_vaps: Vec<String>,
sample_output_aspa_customers: Vec<u32>,
payloads: Option<Vec<Payload>>,
}
fn main() -> Result<()> {
let cli = parse_args(env::args().skip(1))?;
let snapshot = load_ccr_snapshot_from_file(&cli.ccr_path)
.with_context(|| format!("failed to load CCR snapshot: {}", cli.ccr_path.display()))?;
let slurm = load_slurm(&cli.slurm_path)?;
let conversion = snapshot_to_payloads_with_options(&snapshot, cli.strict_ccr)?;
let payloads = slurm.apply(&conversion.payloads);
let (input_vrp_count, input_vap_count) = count_vrps_and_vaps(&conversion.payloads);
let (output_vrp_count, output_vap_count) = count_vrps_and_vaps(&payloads);
let output = Output {
ccr_path: cli.ccr_path.display().to_string(),
slurm_path: cli.slurm_path.display().to_string(),
produced_at: snapshot.produced_at.clone(),
slurm_version: slurm.version().as_u32(),
input_payload_count: conversion.payloads.len(),
input_vrp_count,
input_vap_count,
output_payload_count: payloads.len(),
output_vrp_count,
output_vap_count,
invalid_vrps: conversion.invalid_vrps,
invalid_vaps: conversion.invalid_vaps,
sample_output_aspa_customers: sample_aspa_customers(&payloads, 8),
payloads: cli.dump_payloads.then_some(payloads),
};
println!("{}", serde_json::to_string_pretty(&output)?);
Ok(())
}
fn load_slurm(path: &PathBuf) -> Result<SlurmFile> {
let file = File::open(path)
.with_context(|| format!("failed to open SLURM file: {}", path.display()))?;
SlurmFile::from_reader(file)
.with_context(|| format!("failed to parse SLURM file: {}", path.display()))
}
fn parse_args(args: impl Iterator<Item = String>) -> Result<Cli> {
let mut strict_ccr = false;
let mut dump_payloads = false;
let mut positionals = Vec::new();
for arg in args {
match arg.as_str() {
"--strict-ccr" => strict_ccr = true,
"--dump-payloads" => dump_payloads = true,
"-h" | "--help" => {
print_help();
std::process::exit(0);
}
_ if arg.starts_with('-') => {
return Err(anyhow!("unknown option: {}", arg));
}
_ => positionals.push(arg),
}
}
if positionals.is_empty() {
return Ok(Cli {
ccr_path: find_latest_ccr_file("data")
.context("failed to find latest .ccr in ./data for default run")?,
slurm_path: PathBuf::from("data/example.slurm"),
strict_ccr,
dump_payloads,
});
}
if positionals.len() != 2 {
print_help();
return Err(anyhow!(
"expected: slurm_apply_client <snapshot.ccr> <policy.slurm>"
));
}
Ok(Cli {
ccr_path: PathBuf::from(&positionals[0]),
slurm_path: PathBuf::from(&positionals[1]),
strict_ccr,
dump_payloads,
})
}
fn print_help() {
eprintln!(
"Usage: cargo run --bin slurm_apply_client -- [--strict-ccr] [--dump-payloads] <snapshot.ccr> <policy.slurm>"
);
eprintln!();
eprintln!("Reads a CCR snapshot, converts it into payloads, applies SLURM, and prints JSON.");
eprintln!(
"When no arguments are provided, it defaults to the latest .ccr under ./data and ./data/example.slurm."
);
eprintln!("Use --dump-payloads to include the full payload list in the JSON output.");
}
fn count_vrps_and_vaps(payloads: &[Payload]) -> (usize, usize) {
let mut vrps = 0;
let mut vaps = 0;
for payload in payloads {
match payload {
Payload::RouteOrigin(_) => vrps += 1,
Payload::Aspa(_) => vaps += 1,
Payload::RouterKey(_) => {}
}
}
(vrps, vaps)
}
fn sample_aspa_customers(payloads: &[Payload], limit: usize) -> Vec<u32> {
let mut customers = Vec::new();
for payload in payloads {
if let Payload::Aspa(aspa) = payload {
let customer = aspa.customer_asn().into_u32();
if !customers.contains(&customer) {
customers.push(customer);
if customers.len() == limit {
break;
}
}
}
}
customers
}

View File

@ -1,3 +1,4 @@
pub mod data_model; pub mod data_model;
mod slurm; pub mod slurm;
pub mod rtr; pub mod rtr;
pub mod source;

View File

@ -3,15 +3,15 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{Result, anyhow};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tracing::{info, warn}; use tracing::{info, warn};
use rpki::rtr::ccr::{find_latest_ccr_file, load_ccr_payloads_from_file_with_options, load_ccr_snapshot_from_file};
use rpki::rtr::cache::{RtrCache, SharedRtrCache}; use rpki::rtr::cache::{RtrCache, SharedRtrCache};
use rpki::rtr::payload::Timing; use rpki::rtr::payload::Timing;
use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService}; use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService};
use rpki::rtr::store::RtrStore; use rpki::rtr::store::RtrStore;
use rpki::source::pipeline::{PayloadLoadConfig, load_payloads_from_latest_sources};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct AppConfig { struct AppConfig {
@ -21,6 +21,7 @@ struct AppConfig {
db_path: String, db_path: String,
ccr_dir: String, ccr_dir: String,
slurm_dir: Option<String>,
tls_cert_path: String, tls_cert_path: String,
tls_key_path: String, tls_key_path: String,
tls_client_ca_path: String, tls_client_ca_path: String,
@ -42,6 +43,7 @@ impl Default for AppConfig {
db_path: "./rtr-db".to_string(), db_path: "./rtr-db".to_string(),
ccr_dir: "./data".to_string(), ccr_dir: "./data".to_string(),
slurm_dir: None,
tls_cert_path: "./certs/server.crt".to_string(), tls_cert_path: "./certs/server.crt".to_string(),
tls_key_path: "./certs/server.key".to_string(), tls_key_path: "./certs/server.key".to_string(),
tls_client_ca_path: "./certs/client-ca.crt".to_string(), tls_client_ca_path: "./certs/client-ca.crt".to_string(),
@ -85,6 +87,14 @@ impl AppConfig {
if let Some(value) = env_var("RPKI_RTR_CCR_DIR")? { if let Some(value) = env_var("RPKI_RTR_CCR_DIR")? {
config.ccr_dir = value; config.ccr_dir = value;
} }
if let Some(value) = env_var("RPKI_RTR_SLURM_DIR")? {
let value = value.trim();
config.slurm_dir = if value.is_empty() {
None
} else {
Some(value.to_string())
};
}
if let Some(value) = env_var("RPKI_RTR_TLS_CERT_PATH")? { if let Some(value) = env_var("RPKI_RTR_TLS_CERT_PATH")? {
config.tls_cert_path = value; config.tls_cert_path = value;
} }
@ -104,8 +114,7 @@ impl AppConfig {
parse_bool(&value, "RPKI_RTR_PRUNE_DELTA_BY_SNAPSHOT_SIZE")?; parse_bool(&value, "RPKI_RTR_PRUNE_DELTA_BY_SNAPSHOT_SIZE")?;
} }
if let Some(value) = env_var("RPKI_RTR_STRICT_CCR_VALIDATION")? { if let Some(value) = env_var("RPKI_RTR_STRICT_CCR_VALIDATION")? {
config.strict_ccr_validation = config.strict_ccr_validation = parse_bool(&value, "RPKI_RTR_STRICT_CCR_VALIDATION")?;
parse_bool(&value, "RPKI_RTR_STRICT_CCR_VALIDATION")?;
} }
if let Some(value) = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")? { if let Some(value) = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")? {
let secs: u64 = value.parse().map_err(|err| { let secs: u64 = value.parse().map_err(|err| {
@ -118,9 +127,9 @@ impl AppConfig {
config.refresh_interval = Duration::from_secs(secs); config.refresh_interval = Duration::from_secs(secs);
} }
if let Some(value) = env_var("RPKI_RTR_MAX_CONNECTIONS")? { if let Some(value) = env_var("RPKI_RTR_MAX_CONNECTIONS")? {
config.service_config.max_connections = value.parse().map_err(|err| { config.service_config.max_connections = value
anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err) .parse()
})?; .map_err(|err| anyhow!("invalid RPKI_RTR_MAX_CONNECTIONS '{}': {}", value, err))?;
} }
if let Some(value) = env_var("RPKI_RTR_NOTIFY_QUEUE_SIZE")? { if let Some(value) = env_var("RPKI_RTR_NOTIFY_QUEUE_SIZE")? {
config.service_config.notify_queue_size = value.parse().map_err(|err| { config.service_config.notify_queue_size = value.parse().map_err(|err| {
@ -184,12 +193,17 @@ fn open_store(config: &AppConfig) -> Result<RtrStore> {
} }
fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result<SharedRtrCache> { fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result<SharedRtrCache> {
let payload_load_config = PayloadLoadConfig {
ccr_dir: config.ccr_dir.clone(),
slurm_dir: config.slurm_dir.clone(),
strict_ccr_validation: config.strict_ccr_validation,
};
let initial_cache = RtrCache::default().init( let initial_cache = RtrCache::default().init(
store, store,
config.max_delta, config.max_delta,
config.prune_delta_by_snapshot_size, config.prune_delta_by_snapshot_size,
Timing::default(), Timing::default(),
|| load_payloads_from_latest_ccr(&config.ccr_dir, config.strict_ccr_validation), || load_payloads_from_latest_sources(&payload_load_config),
)?; )?;
let shared_cache: SharedRtrCache = Arc::new(RwLock::new(initial_cache)); let shared_cache: SharedRtrCache = Arc::new(RwLock::new(initial_cache));
@ -232,8 +246,11 @@ fn spawn_refresh_task(
notifier: RtrNotifier, notifier: RtrNotifier,
) -> JoinHandle<()> { ) -> JoinHandle<()> {
let refresh_interval = config.refresh_interval; let refresh_interval = config.refresh_interval;
let ccr_dir = config.ccr_dir.clone(); let payload_load_config = PayloadLoadConfig {
let strict_ccr_validation = config.strict_ccr_validation; ccr_dir: config.ccr_dir.clone(),
slurm_dir: config.slurm_dir.clone(),
strict_ccr_validation: config.strict_ccr_validation,
};
tokio::spawn(async move { tokio::spawn(async move {
let mut interval = tokio::time::interval(refresh_interval); let mut interval = tokio::time::interval(refresh_interval);
@ -241,7 +258,7 @@ fn spawn_refresh_task(
loop { loop {
interval.tick().await; interval.tick().await;
match load_payloads_from_latest_ccr(&ccr_dir, strict_ccr_validation) { match load_payloads_from_latest_sources(&payload_load_config) {
Ok(payloads) => { Ok(payloads) => {
let payload_count = payloads.len(); let payload_count = payloads.len();
let updated = { let updated = {
@ -261,7 +278,7 @@ fn spawn_refresh_task(
if new_serial != old_serial { if new_serial != old_serial {
info!( info!(
"RTR cache refresh applied: ccr_dir={}, payload_count={}, old_serial={}, new_serial={}", "RTR cache refresh applied: ccr_dir={}, payload_count={}, old_serial={}, new_serial={}",
ccr_dir, payload_load_config.ccr_dir,
payload_count, payload_count,
old_serial, old_serial,
new_serial new_serial
@ -270,9 +287,7 @@ fn spawn_refresh_task(
} else { } else {
info!( info!(
"RTR cache refresh found no change: ccr_dir={}, payload_count={}, serial={}", "RTR cache refresh found no change: ccr_dir={}, payload_count={}, serial={}",
ccr_dir, payload_load_config.ccr_dir, payload_count, old_serial
payload_count,
old_serial
); );
false false
} }
@ -290,7 +305,10 @@ fn spawn_refresh_task(
} }
} }
Err(err) => { Err(err) => {
warn!("failed to reload CCR payloads from {}: {:?}", ccr_dir, err); warn!(
"failed to reload CCR/SLURM payloads from {}: {:?}",
payload_load_config.ccr_dir, err
);
} }
} }
} }
@ -317,16 +335,17 @@ fn log_startup_config(config: &AppConfig) {
} }
info!("ccr_dir={}", config.ccr_dir); info!("ccr_dir={}", config.ccr_dir);
info!(
"slurm_dir={}",
config.slurm_dir.as_deref().unwrap_or("disabled")
);
info!("max_delta={}", config.max_delta); info!("max_delta={}", config.max_delta);
info!("strict_ccr_validation={}", config.strict_ccr_validation); info!("strict_ccr_validation={}", config.strict_ccr_validation);
info!( info!(
"refresh_interval_secs={}", "refresh_interval_secs={}",
config.refresh_interval.as_secs() config.refresh_interval.as_secs()
); );
info!( info!("max_connections={}", config.service_config.max_connections);
"max_connections={}",
config.service_config.max_connections
);
info!( info!(
"notify_queue_size={}", "notify_queue_size={}",
config.service_config.notify_queue_size config.service_config.notify_queue_size
@ -372,50 +391,3 @@ fn parse_bool(value: &str, name: &str) -> Result<bool> {
_ => Err(anyhow!("invalid {} '{}': expected boolean", name, value)), _ => Err(anyhow!("invalid {} '{}': expected boolean", name, value)),
} }
} }
fn load_payloads_from_latest_ccr(
ccr_dir: &str,
strict_ccr_validation: bool,
) -> Result<Vec<rpki::rtr::payload::Payload>> {
let latest = find_latest_ccr_file(ccr_dir)?;
let snapshot = load_ccr_snapshot_from_file(&latest)?;
let vrp_count = snapshot.vrps.len();
let vap_count = snapshot.vaps.len();
let produced_at = snapshot.produced_at.clone();
let conversion = load_ccr_payloads_from_file_with_options(&latest, strict_ccr_validation)?;
let payloads = conversion.payloads;
if !conversion.invalid_vrps.is_empty() {
warn!(
"CCR load skipped invalid VRPs: file={}, skipped={}, samples={:?}",
latest.display(),
conversion.invalid_vrps.len(),
sample_messages(&conversion.invalid_vrps)
);
}
if !conversion.invalid_vaps.is_empty() {
warn!(
"CCR load skipped invalid VAPs/ASPAs: file={}, skipped={}, samples={:?}",
latest.display(),
conversion.invalid_vaps.len(),
sample_messages(&conversion.invalid_vaps)
);
}
info!(
"loaded latest CCR snapshot: file={}, produced_at={:?}, vrp_count={}, vap_count={}, payload_count={}, strict_ccr_validation={}",
latest.display(),
produced_at,
vrp_count,
vap_count,
payloads.len(),
strict_ccr_validation
);
Ok(payloads)
}
fn sample_messages(messages: &[String]) -> Vec<&str> {
messages.iter().take(3).map(String::as_str).collect()
}

53
src/rtr/cache/core.rs vendored
View File

@ -1,14 +1,14 @@
use std::collections::{BTreeMap, VecDeque};
use std::cmp::Ordering;
use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeMap, VecDeque};
use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::rtr::payload::{Payload, Timing}; use crate::rtr::payload::{Payload, Timing};
use super::model::{Delta, DualTime, Snapshot}; use super::model::{Delta, DualTime, Snapshot};
use super::ordering::{change_key, ChangeKey}; use super::ordering::{ChangeKey, change_key};
const SERIAL_HALF_RANGE: u32 = 1 << 31; const SERIAL_HALF_RANGE: u32 = 1 << 31;
@ -166,9 +166,7 @@ impl RtrCacheBuilder {
let serial = self.serial.unwrap_or(0); let serial = self.serial.unwrap_or(0);
let created_at = self.created_at.unwrap_or_else(|| now.clone()); let created_at = self.created_at.unwrap_or_else(|| now.clone());
let availability = self.availability.unwrap_or(CacheAvailability::Ready); let availability = self.availability.unwrap_or(CacheAvailability::Ready);
let session_ids = self let session_ids = self.session_ids.unwrap_or_else(SessionIds::random_distinct);
.session_ids
.unwrap_or_else(SessionIds::random_distinct);
RtrCache { RtrCache {
availability, availability,
@ -235,8 +233,7 @@ impl RtrCache {
self.serial = self.serial.wrapping_add(1); self.serial = self.serial.wrapping_add(1);
debug!( debug!(
"RTR cache advanced serial: old_serial={}, new_serial={}", "RTR cache advanced serial: old_serial={}, new_serial={}",
old, old, self.serial
self.serial
); );
self.serial self.serial
} }
@ -251,9 +248,7 @@ impl RtrCache {
let snapshot_wire_size = estimate_snapshot_payload_wire_size(&self.snapshot); let snapshot_wire_size = estimate_snapshot_payload_wire_size(&self.snapshot);
let mut cumulative_delta_wire_size = let mut cumulative_delta_wire_size =
estimate_delta_window_payload_wire_size(&self.deltas); estimate_delta_window_payload_wire_size(&self.deltas);
while !self.deltas.is_empty() while !self.deltas.is_empty() && cumulative_delta_wire_size >= snapshot_wire_size {
&& cumulative_delta_wire_size >= snapshot_wire_size
{
if let Some(oldest) = self.deltas.pop_front() { if let Some(oldest) = self.deltas.pop_front() {
dropped_serials.push(oldest.serial()); dropped_serials.push(oldest.serial());
cumulative_delta_wire_size = cumulative_delta_wire_size =
@ -262,9 +257,7 @@ impl RtrCache {
} }
debug!( debug!(
"RTR cache delta-size pruning evaluated: snapshot_wire_size={}, cumulative_delta_wire_size={}, dropped_serials={:?}", "RTR cache delta-size pruning evaluated: snapshot_wire_size={}, cumulative_delta_wire_size={}, dropped_serials={:?}",
snapshot_wire_size, snapshot_wire_size, cumulative_delta_wire_size, dropped_serials
cumulative_delta_wire_size,
dropped_serials
); );
} }
debug!( debug!(
@ -292,7 +285,10 @@ impl RtrCache {
} }
} }
pub(super) fn apply_update(&mut self, new_payloads: Vec<Payload>) -> Result<Option<AppliedUpdate>> { pub(super) fn apply_update(
&mut self,
new_payloads: Vec<Payload>,
) -> Result<Option<AppliedUpdate>> {
self.last_update_begin = DualTime::now(); self.last_update_begin = DualTime::now();
info!( info!(
"RTR cache applying update: availability={:?}, current_serial={}, incoming_payloads={}", "RTR cache applying update: availability={:?}, current_serial={}, incoming_payloads={}",
@ -319,14 +315,15 @@ impl RtrCache {
self.last_update_end = DualTime::now(); self.last_update_end = DualTime::now();
if !changed { if !changed {
debug!("RTR cache update produced empty snapshot but cache was already unavailable; no state change"); debug!(
"RTR cache update produced empty snapshot but cache was already unavailable; no state change"
);
return Ok(None); return Ok(None);
} }
info!( info!(
"RTR cache update cleared usable data and marked cache unavailable: serial={}, session_ids={:?}", "RTR cache update cleared usable data and marked cache unavailable: serial={}, session_ids={:?}",
self.serial, self.serial, self.session_ids
self.session_ids
); );
return Ok(Some(AppliedUpdate { return Ok(Some(AppliedUpdate {
@ -349,8 +346,7 @@ impl RtrCache {
self.last_update_end = DualTime::now(); self.last_update_end = DualTime::now();
debug!( debug!(
"RTR cache update detected identical snapshot content: serial={}, session_ids={:?}", "RTR cache update detected identical snapshot content: serial={}, session_ids={:?}",
self.serial, self.serial, self.session_ids
self.session_ids
); );
return Ok(None); return Ok(None);
} }
@ -455,8 +451,7 @@ impl RtrCache {
if client_serial == self.serial { if client_serial == self.serial {
debug!( debug!(
"RTR cache delta query is already up to date: client_serial={}, cache_serial={}", "RTR cache delta query is already up to date: client_serial={}, cache_serial={}",
client_serial, client_serial, self.serial
self.serial
); );
return SerialResult::UpToDate; return SerialResult::UpToDate;
} }
@ -467,8 +462,7 @@ impl RtrCache {
) { ) {
warn!( warn!(
"RTR cache delta query requires reset due to invalid/newer client serial: client_serial={}, cache_serial={}", "RTR cache delta query requires reset due to invalid/newer client serial: client_serial={}, cache_serial={}",
client_serial, client_serial, self.serial
self.serial
); );
return SerialResult::ResetRequired; return SerialResult::ResetRequired;
} }
@ -489,8 +483,7 @@ impl RtrCache {
if deltas.is_empty() { if deltas.is_empty() {
debug!( debug!(
"RTR cache delta query resolved to no deltas: client_serial={}, cache_serial={}", "RTR cache delta query resolved to no deltas: client_serial={}, cache_serial={}",
client_serial, client_serial, self.serial
self.serial
); );
return SerialResult::UpToDate; return SerialResult::UpToDate;
} }
@ -633,7 +626,11 @@ fn estimate_payload_wire_size(payload: &Payload, announce: bool) -> usize {
}, },
Payload::RouterKey(key) => 8 + 20 + 4 + key.spki().len(), Payload::RouterKey(key) => 8 + 20 + 4 + key.spki().len(),
Payload::Aspa(aspa) => { Payload::Aspa(aspa) => {
let providers = if announce { aspa.provider_asns().len() } else { 0 }; let providers = if announce {
aspa.provider_asns().len()
} else {
0
};
8 + 4 + providers * 4 8 + 4 + providers * 4
} }
} }

View File

@ -195,7 +195,12 @@ impl Snapshot {
} }
if !self.same_aspas(new_snapshot) { if !self.same_aspas(new_snapshot) {
diff_aspas(&self.aspas, &new_snapshot.aspas, &mut announced, &mut withdrawn); diff_aspas(
&self.aspas,
&new_snapshot.aspas,
&mut announced,
&mut withdrawn,
);
} }
(announced, withdrawn) (announced, withdrawn)
@ -206,9 +211,8 @@ impl Snapshot {
} }
pub fn payloads(&self) -> Vec<Payload> { pub fn payloads(&self) -> Vec<Payload> {
let mut v = Vec::with_capacity( let mut v =
self.origins.len() + self.router_keys.len() + self.aspas.len(), Vec::with_capacity(self.origins.len() + self.router_keys.len() + self.aspas.len());
);
v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin)); v.extend(self.origins.iter().cloned().map(Payload::RouteOrigin));
v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey)); v.extend(self.router_keys.iter().cloned().map(Payload::RouterKey));
@ -268,9 +272,7 @@ impl Snapshot {
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.origins.is_empty() self.origins.is_empty() && self.router_keys.is_empty() && self.aspas.is_empty()
&& self.router_keys.is_empty()
&& self.aspas.is_empty()
} }
} }

View File

@ -64,8 +64,18 @@ enum PayloadPduType {
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum RouteOriginKey { pub(crate) enum RouteOriginKey {
V4 { addr: u32, plen: u8, mlen: u8, asn: u32 }, V4 {
V6 { addr: u128, plen: u8, mlen: u8, asn: u32 }, addr: u32,
plen: u8,
mlen: u8,
asn: u32,
},
V6 {
addr: u128,
plen: u8,
mlen: u8,
asn: u32,
},
} }
pub(crate) fn change_key(payload: &Payload) -> ChangeKey { pub(crate) fn change_key(payload: &Payload) -> ChangeKey {
@ -287,7 +297,11 @@ fn payload_brief(payload: &Payload) -> String {
match payload { match payload {
Payload::RouteOrigin(origin) => format!( Payload::RouteOrigin(origin) => format!(
"{} prefix {:?}/{} max={} asn={}", "{} prefix {:?}/{} max={} asn={}",
if route_origin_is_ipv4(origin) { "IPv4" } else { "IPv6" }, if route_origin_is_ipv4(origin) {
"IPv4"
} else {
"IPv6"
},
origin.prefix().address, origin.prefix().address,
origin.prefix().prefix_length, origin.prefix().prefix_length,
origin.max_length(), origin.max_length(),

View File

@ -18,12 +18,9 @@ impl RtrCache {
timing: Timing, timing: Timing,
file_loader: impl Fn() -> Result<Vec<Payload>>, file_loader: impl Fn() -> Result<Vec<Payload>>,
) -> Result<Self> { ) -> Result<Self> {
if let Some(cache) = try_restore_from_store( if let Some(cache) =
store, try_restore_from_store(store, max_delta, prune_delta_by_snapshot_size, timing)?
max_delta, {
prune_delta_by_snapshot_size,
timing,
)? {
tracing::info!( tracing::info!(
"RTR cache restored from store: availability={:?}, session_ids={:?}, serial={}, snapshot(route_origins={}, router_keys={}, aspas={})", "RTR cache restored from store: availability={:?}, session_ids={:?}, serial={}, snapshot(route_origins={}, router_keys={}, aspas={})",
cache.availability(), cache.availability(),

View File

@ -19,7 +19,6 @@ pub enum ErrorCode {
} }
impl ErrorCode { impl ErrorCode {
#[inline] #[inline]
pub fn as_u16(self) -> u16 { pub fn as_u16(self) -> u16 {
self as u16 self as u16
@ -27,41 +26,29 @@ impl ErrorCode {
pub fn description(self) -> &'static str { pub fn description(self) -> &'static str {
match self { match self {
ErrorCode::CorruptData => ErrorCode::CorruptData => "Corrupt Data",
"Corrupt Data",
ErrorCode::InternalError => ErrorCode::InternalError => "Internal Error",
"Internal Error",
ErrorCode::NoDataAvailable => ErrorCode::NoDataAvailable => "No Data Available",
"No Data Available",
ErrorCode::InvalidRequest => ErrorCode::InvalidRequest => "Invalid Request",
"Invalid Request",
ErrorCode::UnsupportedProtocolVersion => ErrorCode::UnsupportedProtocolVersion => "Unsupported Protocol Version",
"Unsupported Protocol Version",
ErrorCode::UnsupportedPduType => ErrorCode::UnsupportedPduType => "Unsupported PDU Type",
"Unsupported PDU Type",
ErrorCode::WithdrawalOfUnknownRecord => ErrorCode::WithdrawalOfUnknownRecord => "Withdrawal of Unknown Record",
"Withdrawal of Unknown Record",
ErrorCode::DuplicateAnnouncement => ErrorCode::DuplicateAnnouncement => "Duplicate Announcement Received",
"Duplicate Announcement Received",
ErrorCode::UnexpectedProtocolVersion => ErrorCode::UnexpectedProtocolVersion => "Unexpected Protocol Version",
"Unexpected Protocol Version",
ErrorCode::AspaProviderListError => ErrorCode::AspaProviderListError => "ASPA Provider List Error",
"ASPA Provider List Error",
ErrorCode::TransportFailed => ErrorCode::TransportFailed => "Transport Failed",
"Transport Failed",
ErrorCode::OrderingError => ErrorCode::OrderingError => "Ordering Error",
"Ordering Error",
} }
} }
} }
@ -90,9 +77,6 @@ impl TryFrom<u16> for ErrorCode {
impl fmt::Display for ErrorCode { impl fmt::Display for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} ({})", write!(f, "{} ({})", self.description(), *self as u16)
self.description(),
*self as u16
)
} }
} }

View File

@ -3,7 +3,7 @@ use std::net::IpAddr;
use std::path::Path; use std::path::Path;
use std::str::FromStr; use std::str::FromStr;
use anyhow::{anyhow, Context, Result}; use anyhow::{Context, Result, anyhow};
use crate::data_model::resources::as_resources::Asn; use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
@ -131,10 +131,9 @@ pub fn parse_vrp_line(line: &str) -> Result<ParsedVrp> {
} }
let prefix_part = parts[0]; let prefix_part = parts[0];
let max_len = u8::from_str(parts[1]) let max_len =
.with_context(|| format!("invalid max_len: {}", parts[1]))?; u8::from_str(parts[1]).with_context(|| format!("invalid max_len: {}", parts[1]))?;
let asn = u32::from_str(parts[2]) let asn = u32::from_str(parts[2]).with_context(|| format!("invalid asn: {}", parts[2]))?;
.with_context(|| format!("invalid asn: {}", parts[2]))?;
let (addr_str, prefix_len_str) = prefix_part let (addr_str, prefix_len_str) = prefix_part
.split_once('/') .split_once('/')
@ -164,14 +163,13 @@ pub fn parse_aspa_line(line: &str) -> Result<ParsedAspa> {
)); ));
} }
let customer_asn = u32::from_str(parts[0]) let customer_asn =
.with_context(|| format!("invalid customer_asn: {}", parts[0]))?; u32::from_str(parts[0]).with_context(|| format!("invalid customer_asn: {}", parts[0]))?;
let provider_asns = parts[1] let provider_asns = parts[1]
.split_whitespace() .split_whitespace()
.map(|provider| { .map(|provider| {
u32::from_str(provider) u32::from_str(provider).with_context(|| format!("invalid provider_asn: {}", provider))
.with_context(|| format!("invalid provider_asn: {}", provider))
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -186,23 +184,18 @@ pub fn parse_aspa_line(line: &str) -> Result<ParsedAspa> {
pub fn parse_router_key_line(line: &str) -> Result<ParsedRouterKey> { pub fn parse_router_key_line(line: &str) -> Result<ParsedRouterKey> {
let parts: Vec<_> = line.split(',').map(|s| s.trim()).collect(); let parts: Vec<_> = line.split(',').map(|s| s.trim()).collect();
if parts.len() != 3 { if parts.len() != 3 {
return Err(anyhow!( return Err(anyhow!("expected format: <ski_hex>,<asn>,<spki_hex>"));
"expected format: <ski_hex>,<asn>,<spki_hex>"
));
} }
let ski_vec = decode_hex(parts[0]) let ski_vec = decode_hex(parts[0]).with_context(|| format!("invalid SKI hex: {}", parts[0]))?;
.with_context(|| format!("invalid SKI hex: {}", parts[0]))?;
if ski_vec.len() != 20 { if ski_vec.len() != 20 {
return Err(anyhow!("SKI must be exactly 20 bytes")); return Err(anyhow!("SKI must be exactly 20 bytes"));
} }
let mut ski = [0u8; 20]; let mut ski = [0u8; 20];
ski.copy_from_slice(&ski_vec); ski.copy_from_slice(&ski_vec);
let asn = u32::from_str(parts[1]) let asn = u32::from_str(parts[1]).with_context(|| format!("invalid asn: {}", parts[1]))?;
.with_context(|| format!("invalid asn: {}", parts[1]))?; let spki = decode_hex(parts[2]).with_context(|| format!("invalid SPKI hex: {}", parts[2]))?;
let spki = decode_hex(parts[2])
.with_context(|| format!("invalid SPKI hex: {}", parts[2]))?;
validate_router_key(asn, &spki)?; validate_router_key(asn, &spki)?;
@ -254,11 +247,7 @@ fn validate_aspa(customer_asn: u32, provider_asns: &[u32]) -> Result<()> {
} }
fn validate_router_key(asn: u32, spki: &[u8]) -> Result<()> { fn validate_router_key(asn: u32, spki: &[u8]) -> Result<()> {
crate::rtr::payload::RouterKey::new( crate::rtr::payload::RouterKey::new(Ski::default(), Asn::from(asn), spki.to_vec())
Ski::default(),
Asn::from(asn),
spki.to_vec(),
)
.validate() .validate()
.map_err(|err| anyhow!(err.to_string()))?; .map_err(|err| anyhow!(err.to_string()))?;
Ok(()) Ok(())
@ -309,4 +298,3 @@ fn decode_hex(input: &str) -> Result<Vec<u8>> {
}) })
.collect() .collect()
} }

View File

@ -1,10 +1,9 @@
pub mod pdu;
pub mod cache; pub mod cache;
pub mod payload;
pub mod store;
pub mod session;
pub mod error_type; pub mod error_type;
pub mod state;
pub mod server;
pub mod loader; pub mod loader;
pub mod ccr; pub mod payload;
pub mod pdu;
pub mod server;
pub mod session;
pub mod state;
pub mod store;

View File

@ -1,13 +1,12 @@
use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::IPAddressPrefix;
use serde::{Deserialize, Serialize};
use std::fmt::Debug; use std::fmt::Debug;
use std::io; use std::io;
use std::time::Duration; use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::IPAddressPrefix;
use x509_parser::prelude::FromDer; use x509_parser::prelude::FromDer;
use x509_parser::x509::SubjectPublicKeyInfo; use x509_parser::x509::SubjectPublicKeyInfo;
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)] #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum PayloadPduType { enum PayloadPduType {
Ipv4Prefix = 4, Ipv4Prefix = 4,
@ -16,7 +15,9 @@ enum PayloadPduType {
Aspa = 11, Aspa = 11,
} }
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] #[derive(
Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
)]
pub struct Ski([u8; 20]); pub struct Ski([u8; 20]);
impl AsRef<[u8]> for Ski { impl AsRef<[u8]> for Ski {
@ -60,7 +61,6 @@ impl RouteOrigin {
} }
} }
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct RouterKey { pub struct RouterKey {
subject_key_identifier: Ski, subject_key_identifier: Ski,
@ -104,8 +104,8 @@ impl RouterKey {
)); ));
} }
let (rem, _) = SubjectPublicKeyInfo::from_der(&self.subject_public_key_info) let (rem, _) =
.map_err(|err| { SubjectPublicKeyInfo::from_der(&self.subject_public_key_info).map_err(|err| {
io::Error::new( io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!("RouterKey SPKI is not valid DER: {err}"), format!("RouterKey SPKI is not valid DER: {err}"),
@ -115,10 +115,7 @@ impl RouterKey {
if !rem.is_empty() { if !rem.is_empty() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!( format!("RouterKey SPKI DER has trailing bytes: {}", rem.len()),
"RouterKey SPKI DER has trailing bytes: {}",
rem.len()
),
)); ));
} }
@ -177,7 +174,6 @@ impl Aspa {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Payload { pub enum Payload {
@ -191,7 +187,6 @@ pub enum Payload {
Aspa(Aspa), Aspa(Aspa),
} }
// Timing // Timing
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct Timing { pub struct Timing {
@ -202,7 +197,7 @@ pub struct Timing {
pub retry: u32, pub retry: u32,
/// The number of secionds before data expires if not refreshed. /// The number of secionds before data expires if not refreshed.
pub expire: u32 pub expire: u32,
} }
impl Timing { impl Timing {
@ -214,7 +209,11 @@ impl Timing {
pub const MAX_EXPIRE: u32 = 172_800; pub const MAX_EXPIRE: u32 = 172_800;
pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self { pub const fn new(refresh: u32, retry: u32, expire: u32) -> Self {
Self { refresh, retry, expire } Self {
refresh,
retry,
expire,
}
} }
pub fn validate(self) -> Result<(), io::Error> { pub fn validate(self) -> Result<(), io::Error> {
@ -223,7 +222,9 @@ impl Timing {
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!( format!(
"refresh interval {} out of range {}..={}", "refresh interval {} out of range {}..={}",
self.refresh, Self::MIN_REFRESH, Self::MAX_REFRESH self.refresh,
Self::MIN_REFRESH,
Self::MAX_REFRESH
), ),
)); ));
} }
@ -233,7 +234,9 @@ impl Timing {
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!( format!(
"retry interval {} out of range {}..={}", "retry interval {} out of range {}..={}",
self.retry, Self::MIN_RETRY, Self::MAX_RETRY self.retry,
Self::MIN_RETRY,
Self::MAX_RETRY
), ),
)); ));
} }
@ -243,7 +246,9 @@ impl Timing {
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
format!( format!(
"expire interval {} out of range {}..={}", "expire interval {} out of range {}..={}",
self.expire, Self::MIN_EXPIRE, Self::MAX_EXPIRE self.expire,
Self::MIN_EXPIRE,
Self::MAX_EXPIRE
), ),
)); ));
} }
@ -282,7 +287,6 @@ impl Timing {
pub fn expire(self) -> Duration { pub fn expire(self) -> Duration {
Duration::from_secs(u64::from(self.expire)) Duration::from_secs(u64::from(self.expire))
} }
} }
impl Default for Timing { impl Default for Timing {

View File

@ -1,16 +1,16 @@
use std::{cmp, mem};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use crate::data_model::resources::as_resources::Asn; use crate::data_model::resources::as_resources::Asn;
use crate::rtr::error_type::ErrorCode; use crate::rtr::error_type::ErrorCode;
use crate::rtr::payload::{Ski, Timing}; use crate::rtr::payload::{Ski, Timing};
use std::io;
use tokio::io::{AsyncWrite};
use anyhow::Result; use anyhow::Result;
use std::io;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::{cmp, mem};
use tokio::io::AsyncWrite;
use std::slice;
use anyhow::bail; use anyhow::bail;
use serde::Serialize; use serde::Serialize;
use std::slice;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
pub const HEADER_LEN: usize = 8; pub const HEADER_LEN: usize = 8;
@ -27,10 +27,7 @@ macro_rules! common {
#[allow(dead_code)] #[allow(dead_code)]
impl $type { impl $type {
/// Writes a value to a writer. /// Writes a value to a writer.
pub async fn write<A: AsyncWrite + Unpin>( pub async fn write<A: AsyncWrite + Unpin>(&self, a: &mut A) -> Result<(), io::Error> {
&self,
a: &mut A
) -> Result<(), io::Error> {
a.write_all(self.as_ref()).await a.write_all(self.as_ref()).await
} }
} }
@ -38,10 +35,7 @@ macro_rules! common {
impl AsRef<[u8]> for $type { impl AsRef<[u8]> for $type {
fn as_ref(&self) -> &[u8] { fn as_ref(&self) -> &[u8] {
unsafe { unsafe {
slice::from_raw_parts( slice::from_raw_parts(self as *const Self as *const u8, mem::size_of::<Self>())
self as *const Self as *const u8,
mem::size_of::<Self>()
)
} }
} }
} }
@ -49,14 +43,11 @@ macro_rules! common {
impl AsMut<[u8]> for $type { impl AsMut<[u8]> for $type {
fn as_mut(&mut self) -> &mut [u8] { fn as_mut(&mut self) -> &mut [u8] {
unsafe { unsafe {
slice::from_raw_parts_mut( slice::from_raw_parts_mut(self as *mut Self as *mut u8, mem::size_of::<Self>())
self as *mut Self as *mut u8,
mem::size_of::<Self>()
)
}
} }
} }
} }
};
} }
macro_rules! concrete { macro_rules! concrete {
@ -94,28 +85,20 @@ macro_rules! concrete {
/// ///
/// If a value with a different PDU type is received, returns an /// If a value with a different PDU type is received, returns an
/// error. /// error.
pub async fn read<Sock: AsyncRead + Unpin>( pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
sock: &mut Sock
) -> Result<Self, io::Error> {
let mut res = Self::default(); let mut res = Self::default();
sock.read_exact(res.header.as_mut()).await?; sock.read_exact(res.header.as_mut()).await?;
if res.header.pdu() != Self::PDU { if res.header.pdu() != Self::PDU {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
concat!( concat!("PDU type mismatch when expecting ", stringify!($type)),
"PDU type mismatch when expecting ", ));
stringify!($type)
)
))
} }
if res.header.length() as usize != res.as_ref().len() { if res.header.length() as usize != res.as_ref().len() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
concat!( concat!("invalid length for ", stringify!($type)),
"invalid length for ", ));
stringify!($type)
)
))
} }
sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
Ok(res) Ok(res)
@ -126,32 +109,26 @@ macro_rules! concrete {
/// If a different PDU type is received, returns the header as /// If a different PDU type is received, returns the header as
/// the error case of the ok case. /// the error case of the ok case.
pub async fn try_read<Sock: AsyncRead + Unpin>( pub async fn try_read<Sock: AsyncRead + Unpin>(
sock: &mut Sock sock: &mut Sock,
) -> Result<Result<Self, Header>, io::Error> { ) -> Result<Result<Self, Header>, io::Error> {
let mut res = Self::default(); let mut res = Self::default();
sock.read_exact(res.header.as_mut()).await?; sock.read_exact(res.header.as_mut()).await?;
if res.header.pdu() == ErrorReport::PDU { if res.header.pdu() == ErrorReport::PDU {
// Since we should drop the session after an error, we // Since we should drop the session after an error, we
// can safely ignore all the rest of the error for now. // can safely ignore all the rest of the error for now.
return Ok(Err(res.header)) return Ok(Err(res.header));
} }
if res.header.pdu() != Self::PDU { if res.header.pdu() != Self::PDU {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
concat!( concat!("PDU type mismatch when expecting ", stringify!($type)),
"PDU type mismatch when expecting ", ));
stringify!($type)
)
))
} }
if res.header.length() as usize != res.as_ref().len() { if res.header.length() as usize != res.as_ref().len() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
concat!( concat!("invalid length for ", stringify!($type)),
"invalid length for ", ));
stringify!($type)
)
))
} }
sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
Ok(Ok(res)) Ok(Ok(res))
@ -163,17 +140,14 @@ macro_rules! concrete {
/// `header`, the function reads the rest of the PUD from the /// `header`, the function reads the rest of the PUD from the
/// reader and returns the complete value. /// reader and returns the complete value.
pub async fn read_payload<Sock: AsyncRead + Unpin>( pub async fn read_payload<Sock: AsyncRead + Unpin>(
header: Header, sock: &mut Sock header: Header,
sock: &mut Sock,
) -> Result<Self, io::Error> { ) -> Result<Self, io::Error> {
if header.length() as usize != mem::size_of::<Self>() { if header.length() as usize != mem::size_of::<Self>() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
concat!( concat!("invalid length for ", stringify!($type), " PDU"),
"invalid length for ", ));
stringify!($type),
" PDU"
)
))
} }
let mut res = Self::default(); let mut res = Self::default();
sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
@ -181,9 +155,8 @@ macro_rules! concrete {
Ok(res) Ok(res)
} }
} }
};
} }
}
// 所有PDU公共头部信息 // 所有PDU公共头部信息
#[repr(C, packed)] #[repr(C, packed)]
@ -196,7 +169,6 @@ pub struct Header {
} }
impl Header { impl Header {
const LEN: usize = mem::size_of::<Self>(); const LEN: usize = mem::size_of::<Self>();
pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self { pub fn new(version: u8, pdu: u8, session: u16, length: u32) -> Self {
Header { Header {
@ -208,7 +180,7 @@ impl Header {
} }
pub async fn read_raw<S: AsyncRead + Unpin>( pub async fn read_raw<S: AsyncRead + Unpin>(
sock: &mut S sock: &mut S,
) -> Result<[u8; HEADER_LEN], io::Error> { ) -> Result<[u8; HEADER_LEN], io::Error> {
let mut buf = [0u8; HEADER_LEN]; let mut buf = [0u8; HEADER_LEN];
sock.read_exact(&mut buf).await?; sock.read_exact(&mut buf).await?;
@ -229,10 +201,7 @@ impl Header {
} }
if length > MAX_PDU_LEN { if length > MAX_PDU_LEN {
return Err(io::Error::new( return Err(io::Error::new(io::ErrorKind::InvalidData, "PDU too large"));
io::ErrorKind::InvalidData,
"PDU too large",
));
} }
Ok(Self { Ok(Self {
@ -247,13 +216,21 @@ impl Header {
Self::from_raw(Self::read_raw(sock).await?) Self::from_raw(Self::read_raw(sock).await?)
} }
pub fn version(self) -> u8{self.version} pub fn version(self) -> u8 {
self.version
}
pub fn pdu(self) -> u8{self.pdu} pub fn pdu(self) -> u8 {
self.pdu
}
pub fn session_id(self) -> u16{u16::from_be(self.session_id)} pub fn session_id(self) -> u16 {
u16::from_be(self.session_id)
}
pub fn length(self) -> u32{u32::from_be(self.length)} pub fn length(self) -> u32 {
u32::from_be(self.length)
}
pub fn pdu_len(self) -> Result<usize, io::Error> { pub fn pdu_len(self) -> Result<usize, io::Error> {
usize::try_from(self.length()).map_err(|_| { usize::try_from(self.length()).map_err(|_| {
@ -268,7 +245,6 @@ impl Header {
debug_assert_eq!(self.pdu(), ErrorReport::PDU); debug_assert_eq!(self.pdu(), ErrorReport::PDU);
self.session_id() self.session_id()
} }
} }
common!(Header); common!(Header);
@ -304,9 +280,7 @@ impl HeaderWithFlags {
let pdu = buf[1]; let pdu = buf[1];
let flags = buf[2]; let flags = buf[2];
let zero = buf[3]; let zero = buf[3];
let length = u32::from_be_bytes([ let length = u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
buf[4], buf[5], buf[6], buf[7],
]);
// 3. 基础合法性校验 // 3. 基础合法性校验
if length < HEADER_LEN as u32 { if length < HEADER_LEN as u32 {
@ -327,17 +301,26 @@ impl HeaderWithFlags {
}) })
} }
pub fn version(self) -> u8{self.version} pub fn version(self) -> u8 {
self.version
pub fn pdu(self) -> u8{self.pdu}
pub fn flags(self) -> Flags{Flags(self.flags)}
pub fn zero(self) -> u8 { self.zero }
pub fn length(self) -> u32{u32::from_be(self.length)}
} }
pub fn pdu(self) -> u8 {
self.pdu
}
pub fn flags(self) -> Flags {
Flags(self.flags)
}
pub fn zero(self) -> u8 {
self.zero
}
pub fn length(self) -> u32 {
u32::from_be(self.length)
}
}
// Serial Notify // Serial Notify
#[repr(C, packed)] #[repr(C, packed)]
@ -353,18 +336,16 @@ impl SerialNotify {
pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self {
SerialNotify { SerialNotify {
header: Header::new(version, Self::PDU, session_id, Self::size()), header: Header::new(version, Self::PDU, session_id, Self::size()),
serial_number: serial_number.to_be() serial_number: serial_number.to_be(),
} }
} }
pub fn serial_number(self) -> u32 { pub fn serial_number(self) -> u32 {
u32::from_be(self.serial_number) u32::from_be(self.serial_number)
} }
} }
concrete!(SerialNotify); concrete!(SerialNotify);
// Serial Query // Serial Query
#[repr(C, packed)] #[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
@ -379,7 +360,7 @@ impl SerialQuery {
pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self { pub fn new(version: u8, session_id: u16, serial_number: u32) -> Self {
SerialQuery { SerialQuery {
header: Header::new(version, Self::PDU, session_id, Self::size()), header: Header::new(version, Self::PDU, session_id, Self::size()),
serial_number: serial_number.to_be() serial_number: serial_number.to_be(),
} }
} }
@ -390,12 +371,11 @@ impl SerialQuery {
concrete!(SerialQuery); concrete!(SerialQuery);
// Reset Query // Reset Query
#[repr(C, packed)] #[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ResetQuery { pub struct ResetQuery {
header: Header header: Header,
} }
impl ResetQuery { impl ResetQuery {
@ -410,7 +390,6 @@ impl ResetQuery {
concrete!(ResetQuery); concrete!(ResetQuery);
// Cache Response // Cache Response
#[repr(C, packed)] #[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
@ -430,7 +409,6 @@ impl CacheResponse {
concrete!(CacheResponse); concrete!(CacheResponse);
// Flags // Flags
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Flags(u8); pub struct Flags(u8);
@ -464,7 +442,7 @@ pub struct IPv4Prefix {
max_len: u8, max_len: u8,
zero: u8, zero: u8,
prefix: u32, prefix: u32,
asn: u32 asn: u32,
} }
impl IPv4Prefix { impl IPv4Prefix {
@ -475,7 +453,7 @@ impl IPv4Prefix {
prefix_len: u8, prefix_len: u8,
max_len: u8, max_len: u8,
prefix: Ipv4Addr, prefix: Ipv4Addr,
asn: Asn asn: Asn,
) -> Self { ) -> Self {
IPv4Prefix { IPv4Prefix {
header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN), header: Header::new(version, Self::PDU, ZERO_16, IPV4_PREFIX_LEN),
@ -488,12 +466,22 @@ impl IPv4Prefix {
} }
} }
pub fn flag(self) -> Flags{self.flags} pub fn flag(self) -> Flags {
self.flags
}
pub fn prefix_len(self) -> u8{self.prefix_len} pub fn prefix_len(self) -> u8 {
pub fn max_len(self) -> u8{self.max_len} self.prefix_len
pub fn prefix(self) -> Ipv4Addr{u32::from_be(self.prefix).into()} }
pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} pub fn max_len(self) -> u8 {
self.max_len
}
pub fn prefix(self) -> Ipv4Addr {
u32::from_be(self.prefix).into()
}
pub fn asn(self) -> Asn {
u32::from_be(self.asn).into()
}
} }
concrete!(IPv4Prefix); concrete!(IPv4Prefix);
@ -509,7 +497,7 @@ pub struct IPv6Prefix {
max_len: u8, max_len: u8,
zero: u8, zero: u8,
prefix: u128, prefix: u128,
asn: u32 asn: u32,
} }
impl IPv6Prefix { impl IPv6Prefix {
@ -520,7 +508,7 @@ impl IPv6Prefix {
prefix_len: u8, prefix_len: u8,
max_len: u8, max_len: u8,
prefix: Ipv6Addr, prefix: Ipv6Addr,
asn: Asn asn: Asn,
) -> Self { ) -> Self {
IPv6Prefix { IPv6Prefix {
header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN), header: Header::new(version, Self::PDU, ZERO_16, IPV6_PREFIX_LEN),
@ -533,17 +521,26 @@ impl IPv6Prefix {
} }
} }
pub fn flag(self) -> Flags{self.flags} pub fn flag(self) -> Flags {
self.flags
}
pub fn prefix_len(self) -> u8{self.prefix_len} pub fn prefix_len(self) -> u8 {
pub fn max_len(self) -> u8{self.max_len} self.prefix_len
pub fn prefix(self) -> Ipv6Addr{u128::from_be(self.prefix).into()} }
pub fn asn(self) -> Asn{u32::from_be(self.asn).into()} pub fn max_len(self) -> u8 {
self.max_len
}
pub fn prefix(self) -> Ipv6Addr {
u128::from_be(self.prefix).into()
}
pub fn asn(self) -> Asn {
u32::from_be(self.asn).into()
}
} }
concrete!(IPv6Prefix); concrete!(IPv6Prefix);
// End of Data // End of Data
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize)]
pub enum EndOfData { pub enum EndOfData {
@ -559,14 +556,20 @@ impl EndOfData {
timing: Timing, timing: Timing,
) -> Result<Self, io::Error> { ) -> Result<Self, io::Error> {
if version == 0 { if version == 0 {
Ok(EndOfData::V0(EndOfDataV0::new(version, session_id, serial_number))) Ok(EndOfData::V0(EndOfDataV0::new(
} version,
else { session_id,
Ok(EndOfData::V1(EndOfDataV1::new(version, session_id, serial_number, timing)?)) serial_number,
)))
} else {
Ok(EndOfData::V1(EndOfDataV1::new(
version,
session_id,
serial_number,
timing,
)?))
} }
} }
} }
#[repr(C, packed)] #[repr(C, packed)]
@ -587,11 +590,12 @@ impl EndOfDataV0 {
} }
} }
pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} pub fn serial_number(self) -> u32 {
u32::from_be(self.serial_number)
}
} }
concrete!(EndOfDataV0); concrete!(EndOfDataV0);
#[repr(C, packed)] #[repr(C, packed)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct EndOfDataV1 { pub struct EndOfDataV1 {
@ -601,7 +605,6 @@ pub struct EndOfDataV1 {
refresh_interval: u32, refresh_interval: u32,
retry_interval: u32, retry_interval: u32,
expire_interval: u32, expire_interval: u32,
} }
impl EndOfDataV1 { impl EndOfDataV1 {
@ -640,7 +643,9 @@ impl EndOfDataV1 {
}) })
} }
pub fn serial_number(self) -> u32{u32::from_be(self.serial_number)} pub fn serial_number(self) -> u32 {
u32::from_be(self.serial_number)
}
pub fn timing(self) -> Timing { pub fn timing(self) -> Timing {
Timing { Timing {
@ -654,22 +659,20 @@ impl EndOfDataV1 {
self.timing().validate() self.timing().validate()
} }
pub async fn read<Sock: AsyncRead + Unpin>( pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
sock: &mut Sock
) -> Result<Self, io::Error> {
let mut res = Self::default(); let mut res = Self::default();
sock.read_exact(res.header.as_mut()).await?; sock.read_exact(res.header.as_mut()).await?;
if res.header.pdu() != Self::PDU { if res.header.pdu() != Self::PDU {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"PDU type mismatch when expecting EndOfDataV1", "PDU type mismatch when expecting EndOfDataV1",
)) ));
} }
if res.header.length() as usize != mem::size_of::<Self>() { if res.header.length() as usize != mem::size_of::<Self>() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"invalid length for EndOfDataV1", "invalid length for EndOfDataV1",
)) ));
} }
sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
res.validate()?; res.validate()?;
@ -677,13 +680,14 @@ impl EndOfDataV1 {
} }
pub async fn read_payload<Sock: AsyncRead + Unpin>( pub async fn read_payload<Sock: AsyncRead + Unpin>(
header: Header, sock: &mut Sock header: Header,
sock: &mut Sock,
) -> Result<Self, io::Error> { ) -> Result<Self, io::Error> {
if header.length() as usize != mem::size_of::<Self>() { if header.length() as usize != mem::size_of::<Self>() {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"invalid length for EndOfDataV1 PDU", "invalid length for EndOfDataV1 PDU",
)) ));
} }
let mut res = Self::default(); let mut res = Self::default();
sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?; sock.read_exact(&mut res.as_mut()[Header::LEN..]).await?;
@ -706,21 +710,19 @@ impl CacheReset {
pub fn new(version: u8) -> Self { pub fn new(version: u8) -> Self {
CacheReset { CacheReset {
header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32) header: Header::new(version, Self::PDU, ZERO_16, HEADER_LEN as u32),
} }
} }
} }
concrete!(CacheReset); concrete!(CacheReset);
// Error Report // Error Report
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct ErrorReport { pub struct ErrorReport {
octets: Vec<u8>, octets: Vec<u8>,
} }
impl ErrorReport { impl ErrorReport {
/// The PDU type of an error PDU. /// The PDU type of an error PDU.
pub const PDU: u8 = 10; pub const PDU: u8 = 10;
@ -741,27 +743,19 @@ impl ErrorReport {
let text_len = cmp::min(text.len(), text_room); let text_len = cmp::min(text.len(), text_room);
let size = Self::FIXED_PART_LEN + pdu_len + text_len; let size = Self::FIXED_PART_LEN + pdu_len + text_len;
let header = Header::new( let header = Header::new(version, 10, error_code, u32::try_from(size).unwrap());
version, 10, error_code, u32::try_from(size).unwrap()
);
let mut octets = Vec::with_capacity(size); let mut octets = Vec::with_capacity(size);
octets.extend_from_slice(header.as_ref()); octets.extend_from_slice(header.as_ref());
octets.extend_from_slice( octets.extend_from_slice(u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref());
u32::try_from(pdu_len).unwrap().to_be_bytes().as_ref()
);
octets.extend_from_slice(&pdu[..pdu_len]); octets.extend_from_slice(&pdu[..pdu_len]);
octets.extend_from_slice( octets.extend_from_slice(u32::try_from(text_len).unwrap().to_be_bytes().as_ref());
u32::try_from(text_len).unwrap().to_be_bytes().as_ref()
);
octets.extend_from_slice(&text[..text_len]); octets.extend_from_slice(&text[..text_len]);
ErrorReport { octets } ErrorReport { octets }
} }
pub async fn read<Sock: AsyncRead + Unpin>( pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
sock: &mut Sock
) -> Result<Self, io::Error> {
let header = Header::read(sock).await?; let header = Header::read(sock).await?;
if header.pdu() != Self::PDU { if header.pdu() != Self::PDU {
return Err(io::Error::new( return Err(io::Error::new(
@ -787,7 +781,8 @@ impl ErrorReport {
let mut octets = Vec::with_capacity(total_len); let mut octets = Vec::with_capacity(total_len);
octets.extend_from_slice(header.as_ref()); octets.extend_from_slice(header.as_ref());
octets.resize(total_len, 0); octets.resize(total_len, 0);
sock.read_exact(&mut octets[mem::size_of::<Header>()..]).await?; sock.read_exact(&mut octets[mem::size_of::<Header>()..])
.await?;
let res = ErrorReport { octets }; let res = ErrorReport { octets };
res.validate()?; res.validate()?;
@ -813,7 +808,8 @@ impl ErrorReport {
/// Skips over the payload of the error PDU. /// Skips over the payload of the error PDU.
pub async fn skip_payload<Sock: AsyncRead + Unpin>( pub async fn skip_payload<Sock: AsyncRead + Unpin>(
header: Header, sock: &mut Sock header: Header,
sock: &mut Sock,
) -> Result<(), io::Error> { ) -> Result<(), io::Error> {
let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::<Header>()) else { let Some(mut remaining) = header.pdu_len()?.checked_sub(mem::size_of::<Header>()) else {
return Err(io::Error::new( return Err(io::Error::new(
@ -840,9 +836,7 @@ impl ErrorReport {
} }
/// Writes the PUD to a writer. /// Writes the PUD to a writer.
pub async fn write<A: AsyncWrite + Unpin>( pub async fn write<A: AsyncWrite + Unpin>(&self, a: &mut A) -> Result<(), io::Error> {
&self, a: &mut A
) -> Result<(), io::Error> {
a.write_all(self.as_ref()).await a.write_all(self.as_ref()).await
} }
@ -860,7 +854,7 @@ impl ErrorReport {
u32::from_be_bytes( u32::from_be_bytes(
self.octets[Header::LEN..Header::LEN + 4] self.octets[Header::LEN..Header::LEN + 4]
.try_into() .try_into()
.unwrap() .unwrap(),
) as usize ) as usize
} }
@ -876,11 +870,7 @@ impl ErrorReport {
fn text_len(&self) -> usize { fn text_len(&self) -> usize {
let offset = self.text_len_offset(); let offset = self.text_len_offset();
u32::from_be_bytes( u32::from_be_bytes(self.octets[offset..offset + 4].try_into().unwrap()) as usize
self.octets[offset..offset + 4]
.try_into()
.unwrap()
) as usize
} }
fn text_range(&self) -> std::ops::Range<usize> { fn text_range(&self) -> std::ops::Range<usize> {
@ -916,7 +906,10 @@ impl ErrorReport {
let pdu_len = self.erroneous_pdu_len(); let pdu_len = self.erroneous_pdu_len();
let text_len_offset = Header::LEN + 4 + pdu_len; let text_len_offset = Header::LEN + 4 + pdu_len;
let Some(text_len_end) = text_len_offset.checked_add(4) else { let Some(text_len_end) = text_len_offset.checked_add(4) else {
return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport length overflow")); return Err(io::Error::new(
io::ErrorKind::InvalidData,
"ErrorReport length overflow",
));
}; };
if text_len_end > self.octets.len() { if text_len_end > self.octets.len() {
return Err(io::Error::new( return Err(io::Error::new(
@ -928,10 +921,13 @@ impl ErrorReport {
let text_len = u32::from_be_bytes( let text_len = u32::from_be_bytes(
self.octets[text_len_offset..text_len_end] self.octets[text_len_offset..text_len_end]
.try_into() .try_into()
.unwrap() .unwrap(),
) as usize; ) as usize;
let Some(text_end) = text_len_end.checked_add(text_len) else { let Some(text_end) = text_len_end.checked_add(text_len) else {
return Err(io::Error::new(io::ErrorKind::InvalidData, "ErrorReport text overflow")); return Err(io::Error::new(
io::ErrorKind::InvalidData,
"ErrorReport text overflow",
));
}; };
if text_end != self.octets.len() { if text_end != self.octets.len() {
return Err(io::Error::new( return Err(io::Error::new(
@ -951,7 +947,6 @@ impl ErrorReport {
} }
} }
// TODO: 补全 // TODO: 补全
/// Router Key /// Router Key
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
@ -966,13 +961,10 @@ pub struct RouterKey {
} }
impl RouterKey { impl RouterKey {
pub const PDU: u8 = 9; pub const PDU: u8 = 9;
const BASE_LEN: usize = HEADER_LEN + 20 + 4; const BASE_LEN: usize = HEADER_LEN + 20 + 4;
pub async fn read<Sock: AsyncRead + Unpin>( pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
sock: &mut Sock
) -> Result<Self, io::Error> {
let header = HeaderWithFlags::read(sock) let header = HeaderWithFlags::read(sock)
.await .await
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?;
@ -1022,25 +1014,19 @@ impl RouterKey {
Ok(res) Ok(res)
} }
pub async fn write<A: AsyncWrite + Unpin>( pub async fn write<A: AsyncWrite + Unpin>(&self, w: &mut A) -> Result<(), io::Error> {
&self,
w: &mut A,
) -> Result<(), io::Error> {
let length = Self::BASE_LEN + self.subject_public_key_info.len(); let length = Self::BASE_LEN + self.subject_public_key_info.len();
let header = HeaderWithFlags::new( let header =
self.header.version(), HeaderWithFlags::new(self.header.version(), Self::PDU, self.flags, length as u32);
Self::PDU,
self.flags,
length as u32,
);
w.write_all(&[ w.write_all(&[
header.version(), header.version(),
header.pdu(), header.pdu(),
header.flags().into_u8(), header.flags().into_u8(),
ZERO_8, ZERO_8,
]).await?; ])
.await?;
w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(&(length as u32).to_be_bytes()).await?;
w.write_all(self.ski.as_ref()).await?; w.write_all(self.ski.as_ref()).await?;
@ -1120,24 +1106,20 @@ impl RouterKey {
} }
} }
// ASPA // ASPA
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)] #[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize)]
pub struct Aspa { pub struct Aspa {
header: HeaderWithFlags, header: HeaderWithFlags,
customer_asn: u32, customer_asn: u32,
provider_asns: Vec<u32> provider_asns: Vec<u32>,
} }
impl Aspa { impl Aspa {
pub const PDU: u8 = 11; pub const PDU: u8 = 11;
const BASE_LEN: usize = HEADER_LEN + 4; const BASE_LEN: usize = HEADER_LEN + 4;
pub async fn read<Sock: AsyncRead + Unpin>( pub async fn read<Sock: AsyncRead + Unpin>(sock: &mut Sock) -> Result<Self, io::Error> {
sock: &mut Sock
) -> Result<Self, io::Error> {
let header = HeaderWithFlags::read(sock) let header = HeaderWithFlags::read(sock)
.await .await
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?; .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))?;
@ -1192,11 +1174,7 @@ impl Aspa {
Ok(res) Ok(res)
} }
pub async fn write<A: AsyncWrite + Unpin>( pub async fn write<A: AsyncWrite + Unpin>(&self, w: &mut A) -> Result<(), io::Error> {
&self,
w: &mut A,
) -> Result<(), io::Error> {
let length = Self::BASE_LEN + (self.provider_asns.len() * 4); let length = Self::BASE_LEN + (self.provider_asns.len() * 4);
let header = HeaderWithFlags::new( let header = HeaderWithFlags::new(
@ -1211,7 +1189,8 @@ impl Aspa {
header.pdu(), header.pdu(),
header.flags().into_u8(), header.flags().into_u8(),
ZERO_8, ZERO_8,
]).await?; ])
.await?;
w.write_all(&(length as u32).to_be_bytes()).await?; w.write_all(&(length as u32).to_be_bytes()).await?;
w.write_all(&self.customer_asn.to_be_bytes()).await?; w.write_all(&self.customer_asn.to_be_bytes()).await?;
@ -1222,12 +1201,7 @@ impl Aspa {
Ok(()) Ok(())
} }
pub fn new( pub fn new(version: u8, flags: Flags, customer_asn: u32, provider_asns: Vec<u32>) -> Self {
version: u8,
flags: Flags,
customer_asn: u32,
provider_asns: Vec<u32>,
) -> Self {
let length = Self::BASE_LEN + (provider_asns.len() * 4); let length = Self::BASE_LEN + (provider_asns.len() * 4);
Self { Self {
@ -1306,7 +1280,6 @@ impl Aspa {
} }
} }
//--- AsRef and AsMut //--- AsRef and AsMut
impl AsRef<[u8]> for ErrorReport { impl AsRef<[u8]> for ErrorReport {
fn as_ref(&self) -> &[u8] { fn as_ref(&self) -> &[u8] {
@ -1319,4 +1292,3 @@ impl AsMut<[u8]> for ErrorReport {
self.octets.as_mut() self.octets.as_mut()
} }
} }

View File

@ -6,7 +6,7 @@ use std::sync::{
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::{broadcast, watch, OwnedSemaphorePermit}; use tokio::sync::{OwnedSemaphorePermit, broadcast, watch};
use tracing::{info, warn}; use tracing::{info, warn};
use x509_parser::extensions::GeneralName; use x509_parser::extensions::GeneralName;
use x509_parser::prelude::{FromDer, X509Certificate}; use x509_parser::prelude::{FromDer, X509Certificate};
@ -22,10 +22,7 @@ pub struct ConnectionGuard {
} }
impl ConnectionGuard { impl ConnectionGuard {
pub fn new( pub fn new(active_connections: Arc<AtomicUsize>, permit: OwnedSemaphorePermit) -> Self {
active_connections: Arc<AtomicUsize>,
permit: OwnedSemaphorePermit,
) -> Self {
active_connections.fetch_add(1, Ordering::Relaxed); active_connections.fetch_add(1, Ordering::Relaxed);
Self { Self {
active_connections, active_connections,
@ -72,8 +69,12 @@ pub async fn handle_tls_connection(
.await .await
.with_context(|| format!("TLS handshake failed for {}", peer_addr))?; .with_context(|| format!("TLS handshake failed for {}", peer_addr))?;
info!("RTR TLS handshake completed for {}", peer_addr); info!("RTR TLS handshake completed for {}", peer_addr);
verify_peer_certificate_ip(&tls_stream, peer_addr.ip()) verify_peer_certificate_ip(&tls_stream, peer_addr.ip()).with_context(|| {
.with_context(|| format!("TLS client certificate SAN IP validation failed for {}", peer_addr))?; format!(
"TLS client certificate SAN IP validation failed for {}",
peer_addr
)
})?;
info!("RTR TLS client certificate validated for {}", peer_addr); info!("RTR TLS client certificate validated for {}", peer_addr);
let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx); let session = RtrSession::new(cache, tls_stream, notify_rx, shutdown_rx);

View File

@ -1,28 +1,22 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::Path;
use std::sync::{ use std::sync::{Arc, atomic::AtomicUsize};
Arc,
atomic::AtomicUsize,
};
use std::time::Duration; use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::sync::{broadcast, watch, Semaphore}; use tokio::sync::{Semaphore, broadcast, watch};
use tracing::{info, warn}; use tracing::{info, warn};
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::rtr::cache::SharedRtrCache; use crate::rtr::cache::SharedRtrCache;
use crate::rtr::server::connection::{
ConnectionGuard,
handle_tcp_connection,
handle_tls_connection,
is_expected_disconnect,
};
use crate::rtr::server::config::RtrServiceConfig; use crate::rtr::server::config::RtrServiceConfig;
use crate::rtr::server::connection::{
ConnectionGuard, handle_tcp_connection, handle_tls_connection, is_expected_disconnect,
};
use crate::rtr::server::tls::load_rustls_server_config_with_options; use crate::rtr::server::tls::load_rustls_server_config_with_options;
pub struct RtrServer { pub struct RtrServer {
@ -65,7 +59,8 @@ impl RtrServer {
} }
pub fn active_connections(&self) -> usize { pub fn active_connections(&self) -> usize {
self.active_connections.load(std::sync::atomic::Ordering::Relaxed) self.active_connections
.load(std::sync::atomic::Ordering::Relaxed)
} }
pub async fn run_tcp(self) -> Result<()> { pub async fn run_tcp(self) -> Result<()> {
@ -293,10 +288,7 @@ impl RtrServer {
} }
} }
fn apply_keepalive( fn apply_keepalive(stream: &tokio::net::TcpStream, keepalive: Option<Duration>) -> Result<()> {
stream: &tokio::net::TcpStream,
keepalive: Option<Duration>,
) -> Result<()> {
let Some(keepalive) = keepalive else { let Some(keepalive) = keepalive else {
return Ok(()); return Ok(());
}; };

View File

@ -5,7 +5,7 @@ use std::sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
}; };
use tokio::sync::{broadcast, watch, Semaphore}; use tokio::sync::{Semaphore, broadcast, watch};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tracing::{error, warn}; use tracing::{error, warn};
@ -114,7 +114,10 @@ impl RtrService {
let server = self.tls_server(bind_addr); let server = self.tls_server(bind_addr);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = server.run_tls_from_pem(cert_path, key_path, client_ca_path).await { if let Err(err) = server
.run_tls_from_pem(cert_path, key_path, client_ca_path)
.await
{
error!("RTR TLS server {} exited with error: {:?}", bind_addr, err); error!("RTR TLS server {} exited with error: {:?}", bind_addr, err);
} }
}) })
@ -129,7 +132,8 @@ impl RtrService {
client_ca_path: impl AsRef<Path>, client_ca_path: impl AsRef<Path>,
) -> RunningRtrService { ) -> RunningRtrService {
let tcp_handle = self.spawn_tcp(tcp_bind_addr); let tcp_handle = self.spawn_tcp(tcp_bind_addr);
let tls_handle = self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path); let tls_handle =
self.spawn_tls_from_pem(tls_bind_addr, cert_path, key_path, client_ca_path);
RunningRtrService { RunningRtrService {
shutdown_tx: self.shutdown_tx.clone(), shutdown_tx: self.shutdown_tx.clone(),

View File

@ -3,7 +3,7 @@ use std::io::BufReader;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Context, Result}; use anyhow::{Context, Result, anyhow};
use rustls::server::WebPkiClientVerifier; use rustls::server::WebPkiClientVerifier;
use rustls::{RootCertStore, ServerConfig}; use rustls::{RootCertStore, ServerConfig};
use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use rustls_pki_types::{CertificateDer, PrivateKeyDer};
@ -36,8 +36,12 @@ pub fn load_rustls_server_config_with_options(
let key = load_private_key(&key_path) let key = load_private_key(&key_path)
.with_context(|| format!("failed to load private key from {}", key_path.display()))?; .with_context(|| format!("failed to load private key from {}", key_path.display()))?;
let client_ca_certs = load_certs(&client_ca_path) let client_ca_certs = load_certs(&client_ca_path).with_context(|| {
.with_context(|| format!("failed to load client CA certs from {}", client_ca_path.display()))?; format!(
"failed to load client CA certs from {}",
client_ca_path.display()
)
})?;
let mut client_roots = RootCertStore::empty(); let mut client_roots = RootCertStore::empty();
let (added, _) = client_roots.add_parsable_certificates(client_ca_certs); let (added, _) = client_roots.add_parsable_certificates(client_ca_certs);
if added == 0 { if added == 0 {
@ -100,8 +104,7 @@ fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
let file = File::open(path)?; let file = File::open(path)?;
let mut reader = BufReader::new(file); let mut reader = BufReader::new(file);
let certs = rustls_pemfile::certs(&mut reader) let certs = rustls_pemfile::certs(&mut reader).collect::<std::result::Result<Vec<_>, _>>()?;
.collect::<std::result::Result<Vec<_>, _>>()?;
if certs.is_empty() { if certs.is_empty() {
return Err(anyhow!("no certificates found in {}", path.display())); return Err(anyhow!("no certificates found in {}", path.display()));

View File

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use anyhow::{anyhow, bail, Result}; use anyhow::{Result, anyhow, bail};
use tokio::io; use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{broadcast, watch}; use tokio::sync::{broadcast, watch};
@ -14,13 +14,11 @@ use crate::rtr::cache::{
validate_payloads_for_rtr, validate_payloads_for_rtr,
}; };
use crate::rtr::error_type::ErrorCode; use crate::rtr::error_type::ErrorCode;
use crate::rtr::pdu::{
Aspa as AspaPdu,
CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, Header, IPv4Prefix, IPv6Prefix,
ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery,
HEADER_LEN,
};
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey}; use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey};
use crate::rtr::pdu::{
Aspa as AspaPdu, CacheReset, CacheResponse, EndOfData, ErrorReport, Flags, HEADER_LEN, Header,
IPv4Prefix, IPv6Prefix, ResetQuery, RouterKey as RouterKeyPdu, SerialNotify, SerialQuery,
};
const SUPPORTED_MAX_VERSION: u8 = 2; const SUPPORTED_MAX_VERSION: u8 = 2;
const SUPPORTED_MIN_VERSION: u8 = 0; const SUPPORTED_MIN_VERSION: u8 = 0;
@ -83,10 +81,7 @@ where
} }
async fn run_inner(&mut self) -> Result<()> { async fn run_inner(&mut self) -> Result<()> {
info!( info!("RTR session started: {}", self.session_summary());
"RTR session started: {}",
self.session_summary()
);
loop { loop {
let transport_timeout = self.transport_timeout(); let transport_timeout = self.transport_timeout();
tokio::select! { tokio::select! {
@ -265,7 +260,10 @@ where
self.session_summary() self.session_summary()
); );
} else { } else {
debug!("RTR session transport shutdown completed: {}", self.session_summary()); debug!(
"RTR session transport shutdown completed: {}",
self.session_summary()
);
} }
} }
@ -362,8 +360,7 @@ where
) -> io::Result<()> { ) -> io::Result<()> {
let msg = format!( let msg = format!(
"unexpected protocol version {}, established version is {}", "unexpected protocol version {}, established version is {}",
received_version, received_version, established_version
established_version
); );
self.send_error( self.send_error(
@ -402,12 +399,7 @@ where
offending_pdu: &[u8], offending_pdu: &[u8],
detail: &[u8], detail: &[u8],
) -> io::Result<()> { ) -> io::Result<()> {
self.send_error( self.send_error(version, ErrorCode::CorruptData, offending_pdu, detail)
version,
ErrorCode::CorruptData,
offending_pdu,
detail,
)
.await .await
} }
@ -458,7 +450,8 @@ where
self.state = SessionState::Closed; self.state = SessionState::Closed;
bail!( bail!(
"router version {} higher than cache max {}", "router version {} higher than cache max {}",
version, SUPPORTED_MAX_VERSION version,
SUPPORTED_MAX_VERSION
); );
} }
self.negotiate_version(version).await?; self.negotiate_version(version).await?;
@ -495,14 +488,16 @@ where
self.state = SessionState::Closed; self.state = SessionState::Closed;
bail!( bail!(
"router version {} higher than cache max {}", "router version {} higher than cache max {}",
version, SUPPORTED_MAX_VERSION version,
SUPPORTED_MAX_VERSION
); );
} }
self.negotiate_version(version).await?; self.negotiate_version(version).await?;
let session_id = query.session_id(); let session_id = query.session_id();
let serial = query.serial_number(); let serial = query.serial_number();
self.handle_serial(version, session_id, serial, query.as_ref()).await?; self.handle_serial(version, session_id, serial, query.as_ref())
.await?;
self.state = SessionState::Established; self.state = SessionState::Established;
info!( info!(
"RTR session established after Serial Query: negotiated_version={}, client_session_id={}, client_serial={}, {}", "RTR session established after Serial Query: negotiated_version={}, client_session_id={}, client_serial={}, {}",
@ -613,7 +608,10 @@ where
.cache .cache
.read() .read()
.map_err(|_| anyhow!("cache read lock poisoned"))?; .map_err(|_| anyhow!("cache read lock poisoned"))?;
(cache.is_data_available(), cache.session_id_for_version(version)) (
cache.is_data_available(),
cache.session_id_for_version(version),
)
}; };
if !data_available { if !data_available {
@ -723,7 +721,10 @@ where
let now = Instant::now(); let now = Instant::now();
if let Some(last) = self.last_notify_at { if let Some(last) = self.last_notify_at {
if now.duration_since(last) < NOTIFY_MIN_INTERVAL { if now.duration_since(last) < NOTIFY_MIN_INTERVAL {
debug!("RTR session notify skipped due to rate limit: {}", self.session_summary()); debug!(
"RTR session notify skipped due to rate limit: {}",
self.session_summary()
);
return Ok(()); return Ok(());
} }
} }
@ -824,8 +825,7 @@ where
let version = self.version()?; let version = self.version()?;
debug!( debug!(
"RTR session writing Cache Response: version={}, session_id={}", "RTR session writing Cache Response: version={}, session_id={}",
version, version, session_id
session_id
); );
CacheResponse::new(version, session_id) CacheResponse::new(version, session_id)
.write(&mut self.stream) .write(&mut self.stream)
@ -835,10 +835,7 @@ where
async fn write_cache_reset(&mut self) -> Result<()> { async fn write_cache_reset(&mut self) -> Result<()> {
let version = self.version()?; let version = self.version()?;
info!( info!("RTR session writing Cache Reset: version={}", version);
"RTR session writing Cache Reset: version={}",
version
);
CacheReset::new(version).write(&mut self.stream).await?; CacheReset::new(version).write(&mut self.stream).await?;
Ok(()) Ok(())
} }
@ -880,8 +877,7 @@ where
// References: // References:
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12
validate_payloads_for_rtr(payloads, announce) validate_payloads_for_rtr(payloads, announce).map_err(|err| anyhow!(err.to_string()))?;
.map_err(|err| anyhow!(err.to_string()))?;
let (route_origins, router_keys, aspas) = count_payloads(payloads); let (route_origins, router_keys, aspas) = count_payloads(payloads);
debug!( debug!(
"RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}", "RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}",
@ -906,8 +902,7 @@ where
// References: // References:
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12
validate_payload_updates_for_rtr(&updates) validate_payload_updates_for_rtr(&updates).map_err(|err| anyhow!(err.to_string()))?;
.map_err(|err| anyhow!(err.to_string()))?;
let (announced, withdrawn, route_origins, router_keys, aspas) = let (announced, withdrawn, route_origins, router_keys, aspas) =
count_payload_updates(&updates); count_payload_updates(&updates);
debug!( debug!(
@ -1010,8 +1005,7 @@ where
}); });
let providers = if announce { let providers = if announce {
aspa aspa.provider_asns()
.provider_asns()
.iter() .iter()
.map(|asn| asn.into_u32()) .map(|asn| asn.into_u32())
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -1019,18 +1013,12 @@ where
Vec::new() Vec::new()
}; };
let pdu = AspaPdu::new( let pdu = AspaPdu::new(version, flags, aspa.customer_asn().into_u32(), providers);
version,
flags,
aspa.customer_asn().into_u32(),
providers,
);
pdu.write(&mut self.stream).await?; pdu.write(&mut self.stream).await?;
Ok(()) Ok(())
} }
async fn send_error( async fn send_error(
&mut self, &mut self,
version: u8, version: u8,
@ -1052,11 +1040,7 @@ where
.await .await
} }
async fn handle_pdu_read_error( async fn handle_pdu_read_error(&mut self, header: Header, err: io::Error) -> Result<()> {
&mut self,
header: Header,
err: io::Error,
) -> Result<()> {
warn!( warn!(
"RTR session failed to read established-session PDU payload: pdu={}, version={}, err={}", "RTR session failed to read established-session PDU payload: pdu={}, version={}, err={}",
header.pdu(), header.pdu(),
@ -1076,11 +1060,7 @@ where
Ok(()) Ok(())
} }
async fn handle_first_pdu_read_error( async fn handle_first_pdu_read_error(&mut self, header: Header, err: io::Error) -> Result<()> {
&mut self,
header: Header,
err: io::Error,
) -> Result<()> {
warn!( warn!(
"RTR session failed to read first PDU payload: pdu={}, version={}, err={}", "RTR session failed to read first PDU payload: pdu={}, version={}, err={}",
header.pdu(), header.pdu(),
@ -1089,9 +1069,8 @@ where
); );
if err.kind() == io::ErrorKind::InvalidData { if err.kind() == io::ErrorKind::InvalidData {
let offending = self.read_full_pdu_bytes(header).await?; let offending = self.read_full_pdu_bytes(header).await?;
let err_version = if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION) let err_version =
.contains(&header.version()) if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION).contains(&header.version()) {
{
header.version() header.version()
} else { } else {
SUPPORTED_MAX_VERSION SUPPORTED_MAX_VERSION
@ -1114,13 +1093,14 @@ where
) -> Result<()> { ) -> Result<()> {
warn!( warn!(
"RTR session handling invalid header bytes: raw_header={:02X?}, err={}", "RTR session handling invalid header bytes: raw_header={:02X?}, err={}",
raw_header, raw_header, err
err
); );
if err.kind() == io::ErrorKind::InvalidData { if err.kind() == io::ErrorKind::InvalidData {
let version = match self.version { let version = match self.version {
Some(version) => version, Some(version) => version,
None if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION).contains(&raw_header[0]) => { None if (SUPPORTED_MIN_VERSION..=SUPPORTED_MAX_VERSION)
.contains(&raw_header[0]) =>
{
raw_header[0] raw_header[0]
} }
None => SUPPORTED_MAX_VERSION, None => SUPPORTED_MAX_VERSION,
@ -1139,10 +1119,7 @@ where
async fn handle_transport_timeout(&mut self, offending_pdu: &[u8]) -> Result<()> { async fn handle_transport_timeout(&mut self, offending_pdu: &[u8]) -> Result<()> {
let version = self.version.unwrap_or(SUPPORTED_MAX_VERSION); let version = self.version.unwrap_or(SUPPORTED_MAX_VERSION);
let timeout = self.transport_timeout(); let timeout = self.transport_timeout();
let detail = format!( let detail = format!("transport stalled for longer than {:?}", timeout);
"transport stalled for longer than {:?}",
timeout
);
warn!( warn!(
"RTR session transport timeout: version={}, offending_pdu_len={}, timeout={:?}", "RTR session transport timeout: version={}, offending_pdu_len={}, timeout={:?}",
version, version,
@ -1177,7 +1154,8 @@ where
bytes.resize(total_len, 0); bytes.resize(total_len, 0);
timeout( timeout(
self.transport_timeout(), self.transport_timeout(),
self.stream.read_exact(&mut bytes[HEADER_LEN..HEADER_LEN + payload_len]), self.stream
.read_exact(&mut bytes[HEADER_LEN..HEADER_LEN + payload_len]),
) )
.await .await
.map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "transport read timed out"))??; .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "transport read timed out"))??;

View File

@ -1,6 +1,6 @@
use anyhow::{Result, anyhow};
use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch}; use rocksdb::{ColumnFamilyDescriptor, DB, Direction, IteratorMode, Options, WriteBatch};
use anyhow::{anyhow, Result}; use serde::{Serialize, de::DeserializeOwned};
use serde::{de::DeserializeOwned, Serialize};
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio::task; use tokio::task;
@ -66,7 +66,10 @@ impl RtrStore {
/// Common serialize/put. /// Common serialize/put.
fn put_cf<T: Serialize>(&self, cf: &str, key: &[u8], value: &T) -> Result<()> { fn put_cf<T: Serialize>(&self, cf: &str, key: &[u8], value: &T) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; let cf_handle = self
.db
.cf_handle(cf)
.ok_or_else(|| anyhow!("CF not found"))?;
let data = serde_json::to_vec(value)?; let data = serde_json::to_vec(value)?;
self.db.put_cf(cf_handle, key, data)?; self.db.put_cf(cf_handle, key, data)?;
Ok(()) Ok(())
@ -74,7 +77,10 @@ impl RtrStore {
/// Common get/deserialize. /// Common get/deserialize.
fn get_cf<T: DeserializeOwned>(&self, cf: &str, key: &[u8]) -> Result<Option<T>> { fn get_cf<T: DeserializeOwned>(&self, cf: &str, key: &[u8]) -> Result<Option<T>> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; let cf_handle = self
.db
.cf_handle(cf)
.ok_or_else(|| anyhow!("CF not found"))?;
if let Some(value) = self.db.get_cf(cf_handle, key)? { if let Some(value) = self.db.get_cf(cf_handle, key)? {
let obj = serde_json::from_slice(&value)?; let obj = serde_json::from_slice(&value)?;
Ok(Some(obj)) Ok(Some(obj))
@ -85,7 +91,10 @@ impl RtrStore {
/// Common delete. /// Common delete.
fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> { fn delete_cf(&self, cf: &str, key: &[u8]) -> Result<()> {
let cf_handle = self.db.cf_handle(cf).ok_or_else(|| anyhow!("CF not found"))?; let cf_handle = self
.db
.cf_handle(cf)
.ok_or_else(|| anyhow!("CF not found"))?;
self.db.delete_cf(cf_handle, key)?; self.db.delete_cf(cf_handle, key)?;
Ok(()) Ok(())
} }
@ -137,10 +146,12 @@ impl RtrStore {
pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> { pub fn set_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<()> {
debug!( debug!(
"RTR store persisting delta window metadata: min_serial={}, max_serial={}", "RTR store persisting delta window metadata: min_serial={}, max_serial={}",
min_serial, min_serial, max_serial
max_serial
); );
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?); batch.put_cf(meta_cf, META_DELTA_MIN, serde_json::to_vec(&min_serial)?);
batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?); batch.put_cf(meta_cf, META_DELTA_MAX, serde_json::to_vec(&max_serial)?);
@ -150,7 +161,10 @@ impl RtrStore {
pub fn clear_delta_window(&self) -> Result<()> { pub fn clear_delta_window(&self) -> Result<()> {
debug!("RTR store clearing delta window metadata"); debug!("RTR store clearing delta window metadata");
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
batch.delete_cf(meta_cf, META_DELTA_MIN); batch.delete_cf(meta_cf, META_DELTA_MIN);
batch.delete_cf(meta_cf, META_DELTA_MAX); batch.delete_cf(meta_cf, META_DELTA_MAX);
@ -166,8 +180,7 @@ impl RtrStore {
(Some(min), Some(max)) => { (Some(min), Some(max)) => {
debug!( debug!(
"RTR store loaded delta window metadata: min_serial={}, max_serial={}", "RTR store loaded delta window metadata: min_serial={}, max_serial={}",
min, min, max
max
); );
Ok(Some((min, max))) Ok(Some((min, max)))
} }
@ -189,7 +202,10 @@ impl RtrStore {
// =============================== // ===============================
pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> { pub fn save_snapshot(&self, snapshot: &Snapshot) -> Result<()> {
let cf_handle = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let cf_handle = self
.db
.cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
let data = serde_json::to_vec(snapshot)?; let data = serde_json::to_vec(snapshot)?;
batch.put_cf(cf_handle, b"current", data); batch.put_cf(cf_handle, b"current", data);
@ -206,8 +222,14 @@ impl RtrStore {
} }
pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> { pub fn save_snapshot_and_state(&self, snapshot: &Snapshot, state: &State) -> Result<()> {
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let snapshot_cf = self
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; .db
.cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
@ -234,8 +256,14 @@ impl RtrStore {
serial: u32, serial: u32,
) -> Result<()> { ) -> Result<()> {
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let snapshot_cf = self
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; .db
.cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?);
@ -266,15 +294,28 @@ impl RtrStore {
snapshot.router_keys().len(), snapshot.router_keys().len(),
snapshot.aspas().len() snapshot.aspas().len()
); );
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let snapshot_cf = self
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; .db
let delta_cf = self.db.cf_handle(CF_DELTA).ok_or_else(|| anyhow!("CF_DELTA not found"))?; .cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
let delta_cf = self
.db
.cf_handle(CF_DELTA)
.ok_or_else(|| anyhow!("CF_DELTA not found"))?;
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?); batch.put_cf(meta_cf, META_SESSION_IDS, serde_json::to_vec(session_ids)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
batch.put_cf(meta_cf, META_AVAILABILITY, serde_json::to_vec(&availability)?); batch.put_cf(
meta_cf,
META_AVAILABILITY,
serde_json::to_vec(&availability)?,
);
if let Some(delta) = delta { if let Some(delta) = delta {
debug!( debug!(
@ -283,7 +324,11 @@ impl RtrStore {
delta.announced().len(), delta.announced().len(),
delta.withdrawn().len() delta.withdrawn().len()
); );
batch.put_cf(delta_cf, delta_key(delta.serial()), serde_json::to_vec(delta)?); batch.put_cf(
delta_cf,
delta_key(delta.serial()),
serde_json::to_vec(delta)?,
);
} }
if clear_delta_window { if clear_delta_window {
@ -318,8 +363,7 @@ impl RtrStore {
} else { } else {
debug!( debug!(
"RTR store found no stale delta records outside window [{}, {}]", "RTR store found no stale delta records outside window [{}, {}]",
min_serial, min_serial, max_serial
max_serial
); );
} }
for key in stale_keys { for key in stale_keys {
@ -334,8 +378,14 @@ impl RtrStore {
pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> { pub fn save_snapshot_and_serial(&self, snapshot: &Snapshot, serial: u32) -> Result<()> {
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let snapshot_cf = self
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; .db
.cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?); batch.put_cf(snapshot_cf, b"current", serde_json::to_vec(snapshot)?);
batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?); batch.put_cf(meta_cf, META_SERIAL, serde_json::to_vec(&serial)?);
self.db.write(batch)?; self.db.write(batch)?;
@ -352,8 +402,14 @@ impl RtrStore {
task::spawn_blocking(move || { task::spawn_blocking(move || {
let mut batch = WriteBatch::default(); let mut batch = WriteBatch::default();
let snapshot_cf = self.db.cf_handle(CF_SNAPSHOT).ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?; let snapshot_cf = self
let meta_cf = self.db.cf_handle(CF_META).ok_or_else(|| anyhow!("CF_META not found"))?; .db
.cf_handle(CF_SNAPSHOT)
.ok_or_else(|| anyhow!("CF_SNAPSHOT not found"))?;
let meta_cf = self
.db
.cf_handle(CF_META)
.ok_or_else(|| anyhow!("CF_META not found"))?;
batch.put_cf(snapshot_cf, b"current", snapshot_bytes); batch.put_cf(snapshot_cf, b"current", snapshot_bytes);
batch.put_cf(meta_cf, META_SERIAL, serial_bytes); batch.put_cf(meta_cf, META_SERIAL, serial_bytes);
self.db.write(batch)?; self.db.write(batch)?;
@ -370,7 +426,9 @@ impl RtrStore {
match (snapshot, state) { match (snapshot, state) {
(Some(snap), Some(state)) => Ok(Some((snap, state))), (Some(snap), Some(state)) => Ok(Some((snap, state))),
(None, None) => Ok(None), (None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and state mismatch")), _ => Err(anyhow!(
"Inconsistent DB state: snapshot and state mismatch"
)),
} }
} }
@ -380,7 +438,9 @@ impl RtrStore {
match (snapshot, serial) { match (snapshot, serial) {
(Some(snap), Some(serial)) => Ok(Some((snap, serial))), (Some(snap), Some(serial)) => Ok(Some((snap, serial))),
(None, None) => Ok(None), (None, None) => Ok(None),
_ => Err(anyhow!("Inconsistent DB state: snapshot and serial mismatch")), _ => Err(anyhow!(
"Inconsistent DB state: snapshot and serial mismatch"
)),
} }
} }
@ -413,8 +473,8 @@ impl RtrStore {
for item in iter { for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref()) let parsed =
.ok_or_else(|| anyhow!("Invalid delta key"))?; delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?;
if parsed <= serial { if parsed <= serial {
continue; continue;
@ -430,8 +490,7 @@ impl RtrStore {
pub fn load_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<Vec<Delta>> { pub fn load_delta_window(&self, min_serial: u32, max_serial: u32) -> Result<Vec<Delta>> {
info!( info!(
"RTR store loading persisted delta window: min_serial={}, max_serial={}", "RTR store loading persisted delta window: min_serial={}, max_serial={}",
min_serial, min_serial, max_serial
max_serial
); );
let cf_handle = self let cf_handle = self
.db .db
@ -442,8 +501,8 @@ impl RtrStore {
for item in iter { for item in iter {
let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; let (key, value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let parsed = delta_key_serial(key.as_ref()) let parsed =
.ok_or_else(|| anyhow!("Invalid delta key"))?; delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?;
// Restore by the persisted window bounds instead of load_deltas_since(). // Restore by the persisted window bounds instead of load_deltas_since().
// The latter follows lexicographic key order and is not safe across serial // The latter follows lexicographic key order and is not safe across serial
@ -493,7 +552,11 @@ impl RtrStore {
Ok(keys) Ok(keys)
} }
fn list_delta_keys_outside_window(&self, min_serial: u32, max_serial: u32) -> Result<Vec<Vec<u8>>> { fn list_delta_keys_outside_window(
&self,
min_serial: u32,
max_serial: u32,
) -> Result<Vec<Vec<u8>>> {
let cf_handle = self let cf_handle = self
.db .db
.cf_handle(CF_DELTA) .cf_handle(CF_DELTA)
@ -503,8 +566,8 @@ impl RtrStore {
for item in iter { for item in iter {
let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?; let (key, _value) = item.map_err(|e| anyhow!("rocksdb iterator error: {}", e))?;
let serial = delta_key_serial(key.as_ref()) let serial =
.ok_or_else(|| anyhow!("Invalid delta key"))?; delta_key_serial(key.as_ref()).ok_or_else(|| anyhow!("Invalid delta key"))?;
if !serial_in_window(serial, min_serial, max_serial) { if !serial_in_window(serial, min_serial, max_serial) {
keys.push(key.to_vec()); keys.push(key.to_vec());
} }
@ -522,8 +585,7 @@ fn validate_delta_window(deltas: &[Delta], min_serial: u32, max_serial: u32) ->
if deltas.is_empty() { if deltas.is_empty() {
warn!( warn!(
"RTR store delta window validation failed: no persisted deltas for window [{}, {}]", "RTR store delta window validation failed: no persisted deltas for window [{}, {}]",
min_serial, min_serial, max_serial
max_serial
); );
return Err(anyhow!( return Err(anyhow!(
"delta window [{}, {}] has no persisted deltas", "delta window [{}, {}] has no persisted deltas",

251
src/slurm/file.rs Normal file
View File

@ -0,0 +1,251 @@
use std::collections::BTreeSet;
use std::io;
use crate::rtr::payload::Payload;
use crate::slurm::policy::{LocallyAddedAssertions, ValidationOutputFilters, prefix_encompasses};
#[derive(Debug, thiserror::Error)]
pub enum SlurmError {
#[error("failed to parse SLURM JSON: {0}")]
Parse(#[from] serde_json::Error),
#[error("invalid SLURM file: {0}")]
Invalid(String),
#[error("I/O error while reading SLURM file: {0}")]
Io(#[from] io::Error),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlurmVersion {
V1,
V2,
}
impl SlurmVersion {
pub const V1_U32: u32 = 1;
pub const V2_U32: u32 = 2;
pub fn as_u32(self) -> u32 {
match self {
Self::V1 => Self::V1_U32,
Self::V2 => Self::V2_U32,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlurmFile {
version: SlurmVersion,
validation_output_filters: ValidationOutputFilters,
locally_added_assertions: LocallyAddedAssertions,
}
impl SlurmFile {
pub fn new(
version: SlurmVersion,
validation_output_filters: ValidationOutputFilters,
locally_added_assertions: LocallyAddedAssertions,
) -> Result<Self, SlurmError> {
let slurm = Self {
version,
validation_output_filters,
locally_added_assertions,
};
slurm.validate()?;
Ok(slurm)
}
pub fn version(&self) -> SlurmVersion {
self.version
}
pub fn validation_output_filters(&self) -> &ValidationOutputFilters {
&self.validation_output_filters
}
pub fn locally_added_assertions(&self) -> &LocallyAddedAssertions {
&self.locally_added_assertions
}
pub fn apply(&self, payloads: &[Payload]) -> Vec<Payload> {
let mut seen = BTreeSet::new();
let mut result = Vec::new();
for payload in payloads {
if self.validation_output_filters.matches(payload) {
continue;
}
if seen.insert(payload.clone()) {
result.push(payload.clone());
}
}
for assertion in self.locally_added_assertions.to_payloads() {
if seen.insert(assertion.clone()) {
result.push(assertion);
}
}
result
}
pub fn merge_named(files: Vec<(String, SlurmFile)>) -> Result<Self, SlurmError> {
if files.is_empty() {
return Err(SlurmError::Invalid(
"SLURM directory does not contain any .slurm files".to_string(),
));
}
validate_cross_file_conflicts(&files)?;
let mut version = SlurmVersion::V1;
let mut merged_filters = ValidationOutputFilters {
prefix_filters: Vec::new(),
bgpsec_filters: Vec::new(),
aspa_filters: Vec::new(),
};
let mut merged_assertions = LocallyAddedAssertions {
prefix_assertions: Vec::new(),
bgpsec_assertions: Vec::new(),
aspa_assertions: Vec::new(),
};
for (_, file) in files {
if file.version == SlurmVersion::V2 {
version = SlurmVersion::V2;
}
merged_filters.extend_from(file.validation_output_filters());
merged_assertions.extend_from(file.locally_added_assertions());
}
Self::new(version, merged_filters, merged_assertions)
}
fn validate(&self) -> Result<(), SlurmError> {
self.validation_output_filters.validate(self.version)?;
self.locally_added_assertions.validate(self.version)?;
Ok(())
}
}
fn validate_cross_file_conflicts(files: &[(String, SlurmFile)]) -> Result<(), SlurmError> {
for i in 0..files.len() {
for j in (i + 1)..files.len() {
let (name_a, file_a) = &files[i];
let (name_b, file_b) = &files[j];
if prefix_spaces_overlap(file_a, file_b) {
return Err(SlurmError::Invalid(format!(
"conflicting SLURM files: '{}' and '{}' have overlapping prefix spaces",
name_a, name_b
)));
}
if let Some(asn) = bgpsec_asn_overlap(file_a, file_b) {
return Err(SlurmError::Invalid(format!(
"conflicting SLURM files: '{}' and '{}' both constrain BGPsec ASN {}",
name_a, name_b, asn
)));
}
if let Some(customer_asn) = aspa_customer_overlap(file_a, file_b) {
return Err(SlurmError::Invalid(format!(
"conflicting SLURM files: '{}' and '{}' both constrain ASPA customerAsn {}",
name_a, name_b, customer_asn
)));
}
}
}
Ok(())
}
fn prefix_spaces_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> bool {
let mut lhs_prefixes = lhs
.validation_output_filters()
.prefix_filters
.iter()
.filter_map(|f| f.prefix.as_ref())
.chain(
lhs.locally_added_assertions()
.prefix_assertions
.iter()
.map(|a| &a.prefix),
);
let rhs_prefixes = rhs
.validation_output_filters()
.prefix_filters
.iter()
.filter_map(|f| f.prefix.as_ref())
.chain(
rhs.locally_added_assertions()
.prefix_assertions
.iter()
.map(|a| &a.prefix),
)
.collect::<Vec<_>>();
lhs_prefixes.any(|left| {
rhs_prefixes
.iter()
.any(|right| prefix_encompasses(left, right) || prefix_encompasses(right, left))
})
}
fn bgpsec_asn_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> Option<u32> {
let lhs_asns = lhs
.validation_output_filters()
.bgpsec_filters
.iter()
.filter_map(|f| f.asn)
.chain(
lhs.locally_added_assertions()
.bgpsec_assertions
.iter()
.map(|a| a.asn),
)
.collect::<BTreeSet<_>>();
rhs.validation_output_filters()
.bgpsec_filters
.iter()
.filter_map(|f| f.asn)
.chain(
rhs.locally_added_assertions()
.bgpsec_assertions
.iter()
.map(|a| a.asn),
)
.find_map(|asn| lhs_asns.contains(&asn).then_some(asn.into_u32()))
}
fn aspa_customer_overlap(lhs: &SlurmFile, rhs: &SlurmFile) -> Option<u32> {
let lhs_customers = lhs
.validation_output_filters()
.aspa_filters
.iter()
.map(|f| f.customer_asn)
.chain(
lhs.locally_added_assertions()
.aspa_assertions
.iter()
.map(|a| a.customer_asn),
)
.collect::<BTreeSet<_>>();
rhs.validation_output_filters()
.aspa_filters
.iter()
.map(|f| f.customer_asn)
.chain(
rhs.locally_added_assertions()
.aspa_assertions
.iter()
.map(|a| a.customer_asn),
)
.find_map(|asn| lhs_customers.contains(&asn).then_some(asn.into_u32()))
}

View File

@ -1 +1,3 @@
mod slurm; pub mod file;
pub mod policy;
mod serde;

409
src/slurm/policy.rs Normal file
View File

@ -0,0 +1,409 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use crate::data_model::resources::as_resources::Asn;
use crate::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use crate::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski};
use crate::slurm::file::{SlurmError, SlurmVersion};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationOutputFilters {
pub prefix_filters: Vec<PrefixFilter>,
pub bgpsec_filters: Vec<BgpsecFilter>,
pub aspa_filters: Vec<AspaFilter>,
}
impl ValidationOutputFilters {
pub(crate) fn extend_from(&mut self, other: &Self) {
self.prefix_filters
.extend(other.prefix_filters.iter().cloned());
self.bgpsec_filters
.extend(other.bgpsec_filters.iter().cloned());
self.aspa_filters.extend(other.aspa_filters.iter().cloned());
}
pub(crate) fn validate(&self, version: SlurmVersion) -> Result<(), SlurmError> {
if version == SlurmVersion::V1 && !self.aspa_filters.is_empty() {
return Err(SlurmError::Invalid(
"slurmVersion 1 must not contain aspaFilters".to_string(),
));
}
for filter in &self.prefix_filters {
filter.validate()?;
}
for filter in &self.bgpsec_filters {
filter.validate()?;
}
for filter in &self.aspa_filters {
filter.validate()?;
}
Ok(())
}
pub(crate) fn matches(&self, payload: &Payload) -> bool {
match payload {
Payload::RouteOrigin(route_origin) => self
.prefix_filters
.iter()
.any(|filter| filter.matches(route_origin)),
Payload::RouterKey(router_key) => self
.bgpsec_filters
.iter()
.any(|filter| filter.matches(router_key)),
Payload::Aspa(aspa) => self.aspa_filters.iter().any(|filter| filter.matches(aspa)),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocallyAddedAssertions {
pub prefix_assertions: Vec<PrefixAssertion>,
pub bgpsec_assertions: Vec<BgpsecAssertion>,
pub aspa_assertions: Vec<AspaAssertion>,
}
impl LocallyAddedAssertions {
pub(crate) fn extend_from(&mut self, other: &Self) {
self.prefix_assertions
.extend(other.prefix_assertions.iter().cloned());
self.bgpsec_assertions
.extend(other.bgpsec_assertions.iter().cloned());
self.aspa_assertions
.extend(other.aspa_assertions.iter().cloned());
}
pub(crate) fn validate(&self, version: SlurmVersion) -> Result<(), SlurmError> {
if version == SlurmVersion::V1 && !self.aspa_assertions.is_empty() {
return Err(SlurmError::Invalid(
"slurmVersion 1 must not contain aspaAssertions".to_string(),
));
}
for assertion in &self.prefix_assertions {
assertion.validate()?;
}
for assertion in &self.bgpsec_assertions {
assertion.validate()?;
}
for assertion in &self.aspa_assertions {
assertion.validate()?;
}
Ok(())
}
pub(crate) fn to_payloads(&self) -> Vec<Payload> {
let mut payloads = Vec::with_capacity(
self.prefix_assertions.len()
+ self.bgpsec_assertions.len()
+ self.aspa_assertions.len(),
);
payloads.extend(
self.prefix_assertions
.iter()
.cloned()
.map(|assertion| Payload::RouteOrigin(assertion.into_route_origin())),
);
payloads.extend(
self.bgpsec_assertions
.iter()
.cloned()
.map(|assertion| Payload::RouterKey(assertion.into_router_key())),
);
payloads.extend(
self.aspa_assertions
.iter()
.cloned()
.map(|assertion| Payload::Aspa(assertion.into_aspa())),
);
payloads
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefixFilter {
pub prefix: Option<IPAddressPrefix>,
pub asn: Option<Asn>,
pub comment: Option<String>,
}
impl PrefixFilter {
fn validate(&self) -> Result<(), SlurmError> {
if self.prefix.is_none() && self.asn.is_none() {
return Err(SlurmError::Invalid(
"prefixFilter must contain at least one of prefix or asn".to_string(),
));
}
Ok(())
}
fn matches(&self, route_origin: &RouteOrigin) -> bool {
let prefix_match = self
.prefix
.is_none_or(|filter_prefix| prefix_encompasses(&filter_prefix, route_origin.prefix()));
let asn_match = self.asn.is_none_or(|asn| asn == route_origin.asn());
prefix_match && asn_match
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BgpsecFilter {
pub asn: Option<Asn>,
pub ski: Option<Ski>,
pub comment: Option<String>,
}
impl BgpsecFilter {
fn validate(&self) -> Result<(), SlurmError> {
if self.asn.is_none() && self.ski.is_none() {
return Err(SlurmError::Invalid(
"bgpsecFilter must contain at least one of asn or SKI".to_string(),
));
}
Ok(())
}
fn matches(&self, router_key: &RouterKey) -> bool {
let asn_match = self.asn.is_none_or(|asn| asn == router_key.asn());
let ski_match = self.ski.is_none_or(|ski| ski == router_key.ski());
asn_match && ski_match
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AspaFilter {
pub customer_asn: Asn,
pub comment: Option<String>,
}
impl AspaFilter {
fn validate(&self) -> Result<(), SlurmError> {
if self.customer_asn.into_u32() == 0 {
return Err(SlurmError::Invalid(
"aspaFilter customerAsn must not be AS0".to_string(),
));
}
Ok(())
}
fn matches(&self, aspa: &Aspa) -> bool {
self.customer_asn == aspa.customer_asn()
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefixAssertion {
pub prefix: IPAddressPrefix,
pub asn: Asn,
pub max_prefix_length: Option<u8>,
pub comment: Option<String>,
}
impl PrefixAssertion {
fn validate(&self) -> Result<(), SlurmError> {
if self.asn.into_u32() == 0 {
return Err(SlurmError::Invalid(
"prefixAssertion asn must not be AS0".to_string(),
));
}
if let Some(max_prefix_length) = self.max_prefix_length {
let address_bits = prefix_address_bits(&self.prefix);
if max_prefix_length < self.prefix.prefix_length() {
return Err(SlurmError::Invalid(format!(
"prefixAssertion maxPrefixLength {} must be >= prefix length {}",
max_prefix_length,
self.prefix.prefix_length()
)));
}
if max_prefix_length > address_bits {
return Err(SlurmError::Invalid(format!(
"prefixAssertion maxPrefixLength {} exceeds address size {}",
max_prefix_length, address_bits
)));
}
}
Ok(())
}
fn into_route_origin(self) -> RouteOrigin {
RouteOrigin::new(
self.prefix,
self.max_prefix_length
.unwrap_or(self.prefix.prefix_length()),
self.asn,
)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BgpsecAssertion {
pub asn: Asn,
pub ski: Ski,
pub router_public_key: Vec<u8>,
pub comment: Option<String>,
}
impl BgpsecAssertion {
fn validate(&self) -> Result<(), SlurmError> {
RouterKey::new(self.ski, self.asn, self.router_public_key.clone())
.validate()
.map_err(|err| SlurmError::Invalid(err.to_string()))
}
fn into_router_key(self) -> RouterKey {
RouterKey::new(self.ski, self.asn, self.router_public_key)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AspaAssertion {
pub customer_asn: Asn,
pub provider_asns: Vec<Asn>,
pub comment: Option<String>,
}
impl AspaAssertion {
fn validate(&self) -> Result<(), SlurmError> {
let providers = self
.provider_asns
.iter()
.map(|asn| asn.into_u32())
.collect::<Vec<_>>();
if providers.windows(2).any(|window| window[0] >= window[1]) {
return Err(SlurmError::Invalid(
"aspaAssertion providerAsns must be strictly increasing".to_string(),
));
}
let aspa = Aspa::new(self.customer_asn, self.provider_asns.clone());
aspa.validate_announcement()
.map_err(|err| SlurmError::Invalid(err.to_string()))?;
Ok(())
}
fn into_aspa(self) -> Aspa {
Aspa::new(self.customer_asn, self.provider_asns)
}
}
pub(crate) fn parse_ip_prefix(input: &str) -> Result<IPAddressPrefix, SlurmError> {
let (addr, prefix_length) = input
.split_once('/')
.ok_or_else(|| SlurmError::Invalid(format!("invalid prefix '{}'", input)))?;
let address = IpAddr::from_str(addr.trim())
.map_err(|err| SlurmError::Invalid(format!("invalid IP address '{}': {}", addr, err)))?;
let prefix_length = prefix_length.trim().parse::<u8>().map_err(|err| {
SlurmError::Invalid(format!(
"invalid prefix length '{}': {}",
prefix_length, err
))
})?;
match address {
IpAddr::V4(addr) => {
if prefix_length > 32 {
return Err(SlurmError::Invalid(format!(
"IPv4 prefix length {} exceeds 32",
prefix_length
)));
}
if !is_canonical_v4(addr, prefix_length) {
return Err(SlurmError::Invalid(format!(
"IPv4 prefix '{}' is not canonical",
input
)));
}
Ok(IPAddressPrefix::new(
IPAddress::from_ipv4(addr),
prefix_length,
))
}
IpAddr::V6(addr) => {
if prefix_length > 128 {
return Err(SlurmError::Invalid(format!(
"IPv6 prefix length {} exceeds 128",
prefix_length
)));
}
if !is_canonical_v6(addr, prefix_length) {
return Err(SlurmError::Invalid(format!(
"IPv6 prefix '{}' is not canonical",
input
)));
}
Ok(IPAddressPrefix::new(
IPAddress::from_ipv6(addr),
prefix_length,
))
}
}
}
pub(crate) fn prefix_address_bits(prefix: &IPAddressPrefix) -> u8 {
match prefix.address() {
IPAddress::V4(_) => 32,
IPAddress::V6(_) => 128,
}
}
pub(crate) fn prefix_encompasses(filter: &IPAddressPrefix, other: &IPAddressPrefix) -> bool {
if filter.afi() != other.afi() {
return false;
}
if filter.prefix_length() > other.prefix_length() {
return false;
}
match (filter.address(), other.address()) {
(IPAddress::V4(lhs), IPAddress::V4(rhs)) => {
prefix_match_v4(lhs, rhs, filter.prefix_length())
}
(IPAddress::V6(lhs), IPAddress::V6(rhs)) => {
prefix_match_v6(lhs, rhs, filter.prefix_length())
}
_ => false,
}
}
fn prefix_match_v4(lhs: Ipv4Addr, rhs: Ipv4Addr, prefix_length: u8) -> bool {
let mask = if prefix_length == 0 {
0
} else {
u32::MAX << (32 - prefix_length)
};
(u32::from(lhs) & mask) == (u32::from(rhs) & mask)
}
fn prefix_match_v6(lhs: Ipv6Addr, rhs: Ipv6Addr, prefix_length: u8) -> bool {
let mask = if prefix_length == 0 {
0
} else {
u128::MAX << (128 - prefix_length)
};
(u128::from(lhs) & mask) == (u128::from(rhs) & mask)
}
fn is_canonical_v4(addr: Ipv4Addr, prefix_length: u8) -> bool {
let mask = if prefix_length == 0 {
0
} else {
u32::MAX << (32 - prefix_length)
};
(u32::from(addr) & !mask) == 0
}
fn is_canonical_v6(addr: Ipv6Addr, prefix_length: u8) -> bool {
let mask = if prefix_length == 0 {
0
} else {
u128::MAX << (128 - prefix_length)
};
(u128::from(addr) & !mask) == 0
}

313
src/slurm/serde.rs Normal file
View File

@ -0,0 +1,313 @@
use std::io;
use base64::Engine;
use base64::engine::general_purpose::STANDARD_NO_PAD;
use serde::Deserialize;
use crate::data_model::resources::as_resources::Asn;
use crate::rtr::payload::Ski;
use crate::slurm::file::{SlurmError, SlurmFile, SlurmVersion};
use crate::slurm::policy::{
AspaAssertion, AspaFilter, BgpsecAssertion, BgpsecFilter, LocallyAddedAssertions,
PrefixAssertion, PrefixFilter, ValidationOutputFilters, parse_ip_prefix,
};
impl SlurmFile {
pub fn from_slice(input: &[u8]) -> Result<Self, SlurmError> {
let version = serde_json::from_slice::<SlurmVersionMarker>(input)?.slurm_version;
match version {
SlurmVersion::V1_U32 => {
let raw = serde_json::from_slice::<RawSlurmFileV1>(input)?;
Self::from_raw_v1(raw)
}
SlurmVersion::V2_U32 => {
let raw = serde_json::from_slice::<RawSlurmFileV2>(input)?;
Self::from_raw_v2(raw)
}
other => Err(SlurmError::Invalid(format!(
"unsupported slurmVersion {}, expected 1 or 2",
other
))),
}
}
pub fn from_reader(mut reader: impl io::Read) -> Result<Self, SlurmError> {
let mut bytes = Vec::new();
reader.read_to_end(&mut bytes)?;
Self::from_slice(&bytes)
}
fn from_raw_v1(raw: RawSlurmFileV1) -> Result<Self, SlurmError> {
Self::new(
SlurmVersion::V1,
ValidationOutputFilters {
prefix_filters: raw.validation_output_filters.prefix_filters,
bgpsec_filters: raw.validation_output_filters.bgpsec_filters,
aspa_filters: Vec::new(),
},
LocallyAddedAssertions {
prefix_assertions: raw.locally_added_assertions.prefix_assertions,
bgpsec_assertions: raw.locally_added_assertions.bgpsec_assertions,
aspa_assertions: Vec::new(),
},
)
}
fn from_raw_v2(raw: RawSlurmFileV2) -> Result<Self, SlurmError> {
Self::new(
SlurmVersion::V2,
ValidationOutputFilters {
prefix_filters: raw.validation_output_filters.prefix_filters,
bgpsec_filters: raw.validation_output_filters.bgpsec_filters,
aspa_filters: raw.validation_output_filters.aspa_filters,
},
LocallyAddedAssertions {
prefix_assertions: raw.locally_added_assertions.prefix_assertions,
bgpsec_assertions: raw.locally_added_assertions.bgpsec_assertions,
aspa_assertions: raw.locally_added_assertions.aspa_assertions,
},
)
}
}
#[derive(Deserialize)]
struct SlurmVersionMarker {
#[serde(rename = "slurmVersion")]
slurm_version: u32,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawSlurmFileV1 {
#[serde(rename = "slurmVersion")]
_slurm_version: u32,
#[serde(rename = "validationOutputFilters")]
validation_output_filters: RawValidationOutputFiltersV1,
#[serde(rename = "locallyAddedAssertions")]
locally_added_assertions: RawLocallyAddedAssertionsV1,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawSlurmFileV2 {
#[serde(rename = "slurmVersion")]
_slurm_version: u32,
#[serde(rename = "validationOutputFilters")]
validation_output_filters: RawValidationOutputFiltersV2,
#[serde(rename = "locallyAddedAssertions")]
locally_added_assertions: RawLocallyAddedAssertionsV2,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawValidationOutputFiltersV1 {
#[serde(rename = "prefixFilters")]
prefix_filters: Vec<PrefixFilter>,
#[serde(rename = "bgpsecFilters")]
bgpsec_filters: Vec<BgpsecFilter>,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawValidationOutputFiltersV2 {
#[serde(rename = "prefixFilters")]
prefix_filters: Vec<PrefixFilter>,
#[serde(rename = "bgpsecFilters")]
bgpsec_filters: Vec<BgpsecFilter>,
#[serde(rename = "aspaFilters")]
aspa_filters: Vec<AspaFilter>,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawLocallyAddedAssertionsV1 {
#[serde(rename = "prefixAssertions")]
prefix_assertions: Vec<PrefixAssertion>,
#[serde(rename = "bgpsecAssertions")]
bgpsec_assertions: Vec<BgpsecAssertion>,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawLocallyAddedAssertionsV2 {
#[serde(rename = "prefixAssertions")]
prefix_assertions: Vec<PrefixAssertion>,
#[serde(rename = "bgpsecAssertions")]
bgpsec_assertions: Vec<BgpsecAssertion>,
#[serde(rename = "aspaAssertions")]
aspa_assertions: Vec<AspaAssertion>,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawPrefixFilter {
prefix: Option<String>,
asn: Option<u32>,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for PrefixFilter {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawPrefixFilter::deserialize(deserializer)?;
Ok(Self {
prefix: raw
.prefix
.map(|prefix| parse_ip_prefix(&prefix))
.transpose()
.map_err(serde::de::Error::custom)?,
asn: raw.asn.map(Asn::from),
comment: raw.comment,
})
}
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawBgpsecFilter {
asn: Option<u32>,
#[serde(rename = "SKI")]
ski: Option<String>,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for BgpsecFilter {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawBgpsecFilter::deserialize(deserializer)?;
Ok(Self {
asn: raw.asn.map(Asn::from),
ski: raw
.ski
.map(|ski| decode_ski(&ski))
.transpose()
.map_err(serde::de::Error::custom)?,
comment: raw.comment,
})
}
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawAspaFilter {
#[serde(rename = "customerAsn")]
customer_asn: u32,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for AspaFilter {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawAspaFilter::deserialize(deserializer)?;
Ok(Self {
customer_asn: Asn::from(raw.customer_asn),
comment: raw.comment,
})
}
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawPrefixAssertion {
prefix: String,
asn: u32,
#[serde(rename = "maxPrefixLength")]
max_prefix_length: Option<u8>,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for PrefixAssertion {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawPrefixAssertion::deserialize(deserializer)?;
Ok(Self {
prefix: parse_ip_prefix(&raw.prefix).map_err(serde::de::Error::custom)?,
asn: Asn::from(raw.asn),
max_prefix_length: raw.max_prefix_length,
comment: raw.comment,
})
}
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawBgpsecAssertion {
asn: u32,
#[serde(rename = "SKI")]
ski: String,
#[serde(rename = "routerPublicKey")]
router_public_key: String,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for BgpsecAssertion {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawBgpsecAssertion::deserialize(deserializer)?;
Ok(Self {
asn: Asn::from(raw.asn),
ski: decode_ski(&raw.ski).map_err(serde::de::Error::custom)?,
router_public_key: decode_router_public_key(&raw.router_public_key)
.map_err(serde::de::Error::custom)?,
comment: raw.comment,
})
}
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct RawAspaAssertion {
#[serde(rename = "customerAsn")]
customer_asn: u32,
#[serde(rename = "providerAsns")]
provider_asns: Vec<u32>,
comment: Option<String>,
}
impl<'de> Deserialize<'de> for AspaAssertion {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = RawAspaAssertion::deserialize(deserializer)?;
Ok(Self {
customer_asn: Asn::from(raw.customer_asn),
provider_asns: raw.provider_asns.into_iter().map(Asn::from).collect(),
comment: raw.comment,
})
}
}
fn decode_ski(input: &str) -> Result<Ski, SlurmError> {
let bytes = hex::decode(input)
.map_err(|err| SlurmError::Invalid(format!("invalid SKI '{}': {}", input, err)))?;
if bytes.len() != 20 {
return Err(SlurmError::Invalid(format!(
"SKI must be exactly 20 bytes, got {}",
bytes.len()
)));
}
let mut ski = [0u8; 20];
ski.copy_from_slice(&bytes);
Ok(Ski::from_bytes(ski))
}
fn decode_router_public_key(input: &str) -> Result<Vec<u8>, SlurmError> {
STANDARD_NO_PAD.decode(input).map_err(|err| {
SlurmError::Invalid(format!(
"invalid routerPublicKey base64 '{}': {}",
input, err
))
})
}

View File

@ -1,80 +0,0 @@
use std::io;
use crate::data_model::resources::as_resources::Asn;
#[derive(Debug, thiserror::Error)]
pub enum SlurmError {
#[error("Read slurm from reader error")]
SlurmFromReader(),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlurmFile {
pub version: u32,
pub validation_output_filters: ValidationOutputFilters,
pub locally_added_assertions: LocallyAddedAssertions,
}
impl SlurmFile {
pub fn new(filters: ValidationOutputFilters,
assertions: LocallyAddedAssertions,) -> Self {
let version = 1;
SlurmFile {
version,
validation_output_filters: filters,
locally_added_assertions: assertions,
}
}
// pub fn from_reader(reader: impl io::Read)-> Result<Self, SlurmError> {
//
// }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationOutputFilters {
pub prefix_filters: Vec<PrefixFilter>,
pub bgpset_filters: Vec<BgpsecFilter>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Comment(String);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefixFilter {
pub prefix: String,
pub asn: Asn,
pub comment: Option<Comment>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BgpsecFilter {
pub asn: Asn,
pub ski: u8,
pub comment: Option<Comment>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocallyAddedAssertions {
pub prefix_assertions: Vec<PrefixAssertion>,
pub bgpsec_assertions: Vec<BgpsecAssertion>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrefixAssertion {
pub prefix: String,
pub asn: Asn,
pub max_prefix_length: u8,
pub comment: Option<Comment>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BgpsecAssertion {
pub asn: Asn,
pub ski: u8,
pub router_public_key: u8,
pub comment: Option<Comment>,
}

View File

@ -2,7 +2,7 @@ use std::fs;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use anyhow::{anyhow, Context, Result}; use anyhow::{Context, Result, anyhow};
use der_parser::ber::{BerObject, BerObjectContent}; use der_parser::ber::{BerObject, BerObjectContent};
use der_parser::der::parse_der; use der_parser::der::parse_der;
@ -29,8 +29,8 @@ pub struct CcrPayloadConversion {
pub fn load_ccr_snapshot_from_file(path: impl AsRef<Path>) -> Result<ParsedCcrSnapshot> { pub fn load_ccr_snapshot_from_file(path: impl AsRef<Path>) -> Result<ParsedCcrSnapshot> {
let path = path.as_ref(); let path = path.as_ref();
let bytes = fs::read(path) let bytes =
.with_context(|| format!("failed to read CCR file: {}", path.display()))?; fs::read(path).with_context(|| format!("failed to read CCR file: {}", path.display()))?;
parse_ccr_bytes(&bytes).with_context(|| format!("failed to parse CCR file: {}", path.display())) parse_ccr_bytes(&bytes).with_context(|| format!("failed to parse CCR file: {}", path.display()))
} }
@ -64,9 +64,10 @@ pub fn find_latest_ccr_file(dir: impl AsRef<Path>) -> Result<PathBuf> {
continue; continue;
} }
if latest.as_ref().is_none_or(|current| { if latest
file_name_key(&path) > file_name_key(current) .as_ref()
}) { .is_none_or(|current| file_name_key(&path) > file_name_key(current))
{
latest = Some(path); latest = Some(path);
} }
} }
@ -250,10 +251,7 @@ fn parse_vaps(field: &BerObject<'_>) -> Result<Vec<ParsedAspa>> {
Ok(vaps) Ok(vaps)
} }
fn parse_roa_address( fn parse_roa_address(address_family: &[u8], items: &[BerObject<'_>]) -> Result<(IpAddr, u8, u8)> {
address_family: &[u8],
items: &[BerObject<'_>],
) -> Result<(IpAddr, u8, u8)> {
let address = items let address = items
.first() .first()
.ok_or_else(|| anyhow!("ROAIPAddress missing address field"))?; .ok_or_else(|| anyhow!("ROAIPAddress missing address field"))?;
@ -275,8 +273,7 @@ fn parse_roa_address(
let max_len = match items.get(1) { let max_len = match items.get(1) {
Some(value) => { Some(value) => {
let max_len = as_u32(value, "ROAIPAddress.maxLength")?; let max_len = as_u32(value, "ROAIPAddress.maxLength")?;
u8::try_from(max_len) u8::try_from(max_len).map_err(|_| anyhow!("maxLength {max_len} does not fit in u8"))?
.map_err(|_| anyhow!("maxLength {max_len} does not fit in u8"))?
} }
None => prefix_len, None => prefix_len,
}; };
@ -328,10 +325,7 @@ fn decode_context_wrapped_sequence<'a>(obj: &'a BerObject<'a>) -> Result<BerObje
let (rem, inner) = parse_der(any.data) let (rem, inner) = parse_der(any.data)
.map_err(|err| anyhow!("failed to parse encapsulated DER: {err}"))?; .map_err(|err| anyhow!("failed to parse encapsulated DER: {err}"))?;
if !rem.is_empty() { if !rem.is_empty() {
return Err(anyhow!( return Err(anyhow!("encapsulated DER has {} trailing bytes", rem.len()));
"encapsulated DER has {} trailing bytes",
rem.len()
));
} }
Ok(inner) Ok(inner)
} }

2
src/source/mod.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod ccr;
pub mod pipeline;

138
src/source/pipeline.rs Normal file
View File

@ -0,0 +1,138 @@
use anyhow::{Result, anyhow};
use std::path::PathBuf;
use tracing::{info, warn};
use crate::rtr::payload::Payload;
use crate::slurm::file::SlurmFile;
use crate::source::ccr::{
find_latest_ccr_file, load_ccr_payloads_from_file_with_options, load_ccr_snapshot_from_file,
};
#[derive(Debug, Clone)]
pub struct PayloadLoadConfig {
pub ccr_dir: String,
pub slurm_dir: Option<String>,
pub strict_ccr_validation: bool,
}
pub fn load_payloads_from_latest_sources(config: &PayloadLoadConfig) -> Result<Vec<Payload>> {
let payloads = load_payloads_from_latest_ccr(&config.ccr_dir, config.strict_ccr_validation)?;
match config.slurm_dir.as_deref() {
Some(dir) => apply_slurm_to_payloads_from_dir(dir, payloads),
None => Ok(payloads),
}
}
fn load_payloads_from_latest_ccr(
ccr_dir: &str,
strict_ccr_validation: bool,
) -> Result<Vec<Payload>> {
let latest = find_latest_ccr_file(ccr_dir)?;
let snapshot = load_ccr_snapshot_from_file(&latest)?;
let vrp_count = snapshot.vrps.len();
let vap_count = snapshot.vaps.len();
let produced_at = snapshot.produced_at.clone();
let conversion = load_ccr_payloads_from_file_with_options(&latest, strict_ccr_validation)?;
let payloads = conversion.payloads;
if !conversion.invalid_vrps.is_empty() {
warn!(
"CCR load skipped invalid VRPs: file={}, skipped={}, samples={:?}",
latest.display(),
conversion.invalid_vrps.len(),
sample_messages(&conversion.invalid_vrps)
);
}
if !conversion.invalid_vaps.is_empty() {
warn!(
"CCR load skipped invalid VAPs/ASPAs: file={}, skipped={}, samples={:?}",
latest.display(),
conversion.invalid_vaps.len(),
sample_messages(&conversion.invalid_vaps)
);
}
info!(
"loaded latest CCR snapshot: file={}, produced_at={:?}, vrp_count={}, vap_count={}, payload_count={}, strict_ccr_validation={}",
latest.display(),
produced_at,
vrp_count,
vap_count,
payloads.len(),
strict_ccr_validation
);
Ok(payloads)
}
fn apply_slurm_to_payloads_from_dir(
slurm_dir: &str,
payloads: Vec<Payload>,
) -> Result<Vec<Payload>> {
let files = read_slurm_files(slurm_dir)?;
let file_count = files.len();
let file_names = files
.iter()
.map(|(name, _)| name.clone())
.collect::<Vec<_>>();
let slurm = SlurmFile::merge_named(files)
.map_err(|err| anyhow!("failed to merge SLURM files from '{}': {}", slurm_dir, err))?;
let input_count = payloads.len();
let filtered = slurm.apply(&payloads);
let output_count = filtered.len();
info!(
"applied SLURM policy set: slurm_dir={}, file_count={}, files={:?}, merged_slurm_version={}, input_payload_count={}, output_payload_count={}",
slurm_dir,
file_count,
file_names,
slurm.version().as_u32(),
input_count,
output_count
);
Ok(filtered)
}
fn read_slurm_files(slurm_dir: &str) -> Result<Vec<(String, SlurmFile)>> {
let mut paths = std::fs::read_dir(slurm_dir)
.map_err(|err| anyhow!("failed to read SLURM directory '{}': {}", slurm_dir, err))?
.filter_map(|entry| entry.ok())
.map(|entry| entry.path())
.filter(|path| path.is_file())
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("slurm"))
.collect::<Vec<PathBuf>>();
paths.sort_by_key(|path| {
path.file_name()
.and_then(|name| name.to_str())
.map(|name| name.to_ascii_lowercase())
.unwrap_or_default()
});
if paths.is_empty() {
return Err(anyhow!(
"SLURM directory '{}' does not contain .slurm files",
slurm_dir
));
}
paths
.into_iter()
.map(|path| {
let name = path.to_string_lossy().to_string();
let file = std::fs::File::open(&path)
.map_err(|err| anyhow!("failed to open SLURM file '{}': {}", name, err))?;
let slurm = SlurmFile::from_reader(file)
.map_err(|err| anyhow!("failed to parse SLURM file '{}': {}", name, err))?;
Ok((name, slurm))
})
.collect()
}
fn sample_messages(messages: &[String]) -> Vec<&str> {
messages.iter().take(3).map(String::as_str).collect()
}

View File

@ -1,7 +1,7 @@
use std::fmt::Write; use std::fmt::Write;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
use serde_json::{json, Value}; use serde_json::{Value, json};
use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix}; use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use rpki::rtr::cache::SerialResult; use rpki::rtr::cache::SerialResult;
@ -14,7 +14,9 @@ pub struct RtrDebugDumper {
impl RtrDebugDumper { impl RtrDebugDumper {
pub fn new() -> Self { pub fn new() -> Self {
Self { entries: Vec::new() } Self {
entries: Vec::new(),
}
} }
pub fn push<T: serde::Serialize>(&mut self, pdu: u8, body: &T) { pub fn push<T: serde::Serialize>(&mut self, pdu: u8, body: &T) {
@ -150,15 +152,7 @@ pub fn v6_prefix(addr: Ipv6Addr, prefix_len: u8) -> IPAddressPrefix {
} }
} }
pub fn v4_origin( pub fn v4_origin(a: u8, b: u8, c: u8, d: u8, prefix_len: u8, max_len: u8, asn: u32) -> RouteOrigin {
a: u8,
b: u8,
c: u8,
d: u8,
prefix_len: u8,
max_len: u8,
asn: u32,
) -> RouteOrigin {
let prefix = v4_prefix(a, b, c, d, prefix_len); let prefix = v4_prefix(a, b, c, d, prefix_len);
RouteOrigin::new(prefix, max_len, asn.into()) RouteOrigin::new(prefix, max_len, asn.into())
} }
@ -238,7 +232,11 @@ pub fn serial_result_to_string(result: &SerialResult) -> String {
} }
pub fn print_serial_result(label: &str, result: &SerialResult) { pub fn print_serial_result(label: &str, result: &SerialResult) {
println!("\n===== {} =====\n{}\n", label, serial_result_to_string(result)); println!(
"\n===== {} =====\n{}\n",
label,
serial_result_to_string(result)
);
} }
pub fn bytes_to_hex(bytes: &[u8]) -> String { pub fn bytes_to_hex(bytes: &[u8]) -> String {
@ -290,12 +288,8 @@ pub fn snapshot_hashes_to_string(snapshot: &rpki::rtr::cache::Snapshot) -> Strin
pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -> String { pub fn serial_result_detail_to_string(result: &rpki::rtr::cache::SerialResult) -> String {
match result { match result {
rpki::rtr::cache::SerialResult::UpToDate => { rpki::rtr::cache::SerialResult::UpToDate => " result: UpToDate\n".to_string(),
" result: UpToDate\n".to_string() rpki::rtr::cache::SerialResult::ResetRequired => " result: ResetRequired\n".to_string(),
}
rpki::rtr::cache::SerialResult::ResetRequired => {
" result: ResetRequired\n".to_string()
}
rpki::rtr::cache::SerialResult::Delta(delta) => { rpki::rtr::cache::SerialResult::Delta(delta) => {
let mut out = String::new(); let mut out = String::new();
let _ = writeln!(&mut out, " result: Delta"); let _ = writeln!(&mut out, " result: Delta");

View File

@ -1,4 +1,4 @@
mod common; mod common;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr};
@ -12,8 +12,7 @@ use common::test_helper::{
use rpki::data_model::resources::as_resources::Asn; use rpki::data_model::resources::as_resources::Asn;
use rpki::rtr::cache::{ use rpki::rtr::cache::{
CacheAvailability, Delta, RtrCacheBuilder, SerialResult, SessionIds, Snapshot, CacheAvailability, Delta, RtrCacheBuilder, SerialResult, SessionIds, Snapshot,
validate_payload_updates_for_rtr, validate_payload_updates_for_rtr, validate_payloads_for_rtr,
validate_payloads_for_rtr,
}; };
use rpki::rtr::payload::{Aspa, Payload, RouterKey, Ski, Timing}; use rpki::rtr::payload::{Aspa, Payload, RouterKey, Ski, Timing};
use rpki::rtr::store::RtrStore; use rpki::rtr::store::RtrStore;
@ -40,7 +39,11 @@ fn deltas_window_to_string(deltas: &VecDeque<Arc<Delta>>) -> String {
out out
} }
fn get_deltas_since_input_to_string(cache_session_id: u16, cache_serial: u32, client_serial: u32) -> String { fn get_deltas_since_input_to_string(
cache_session_id: u16,
cache_serial: u32,
client_serial: u32,
) -> String {
format!( format!(
"cache.session_id: {}\ncache.serial: {}\nclient_serial: {}\n", "cache.session_id: {}\ncache.serial: {}\nclient_serial: {}\n",
cache_session_id, cache_serial, client_serial cache_session_id, cache_serial, client_serial
@ -118,7 +121,9 @@ async fn init_keeps_cache_running_when_file_loader_returns_no_data() {
let store = RtrStore::open(dir.path()).unwrap(); let store = RtrStore::open(dir.path()).unwrap();
let cache = rpki::rtr::cache::RtrCache::default() let cache = rpki::rtr::cache::RtrCache::default()
.init(&store, 16, false, Timing::new(600, 600, 7200), || Ok(vec![])) .init(&store, 16, false, Timing::new(600, 600, 7200), || {
Ok(vec![])
})
.unwrap(); .unwrap();
assert!(!cache.is_data_available()); assert!(!cache.is_data_available());
@ -144,12 +149,16 @@ async fn init_restores_wraparound_delta_window_from_store() {
); );
let d_zero = Delta::new( let d_zero = Delta::new(
0, 0,
vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], vec![Payload::RouteOrigin(v4_origin(
198, 51, 100, 0, 24, 24, 64497,
))],
vec![], vec![],
); );
let d_one = Delta::new( let d_one = Delta::new(
1, 1,
vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], vec![Payload::RouteOrigin(v4_origin(
203, 0, 113, 0, 24, 24, 64498,
))],
vec![], vec![],
); );
@ -188,7 +197,9 @@ async fn init_restores_wraparound_delta_window_from_store() {
.unwrap(); .unwrap();
let cache = rpki::rtr::cache::RtrCache::default() let cache = rpki::rtr::cache::RtrCache::default()
.init(&store, 16, false, Timing::new(600, 600, 7200), || Ok(Vec::new())) .init(&store, 16, false, Timing::new(600, 600, 7200), || {
Ok(Vec::new())
})
.unwrap(); .unwrap();
match cache.get_deltas_since(u32::MAX.wrapping_sub(1)) { match cache.get_deltas_since(u32::MAX.wrapping_sub(1)) {
@ -205,8 +216,8 @@ async fn update_prunes_delta_window_when_cumulative_delta_size_reaches_snapshot_
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let store = RtrStore::open(dir.path()).unwrap(); let store = RtrStore::open(dir.path()).unwrap();
let valid_spki = vec![ let valid_spki = vec![
0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
0x01, 0x01, 0x05, 0x00, 0x03, 0x02, 0x00, 0x00, 0x05, 0x00, 0x03, 0x02, 0x00, 0x00,
]; ];
let initial_snapshot = Snapshot::from_payloads(vec![Payload::RouterKey(RouterKey::new( let initial_snapshot = Snapshot::from_payloads(vec![Payload::RouterKey(RouterKey::new(
@ -226,11 +237,14 @@ async fn update_prunes_delta_window_when_cumulative_delta_size_reaches_snapshot_
.build(); .build();
cache cache
.update(vec![Payload::RouterKey(RouterKey::new( .update(
vec![Payload::RouterKey(RouterKey::new(
Ski::from_bytes([1u8; 20]), Ski::from_bytes([1u8; 20]),
Asn::from(64496u32), Asn::from(64496u32),
valid_spki, valid_spki,
))], &store) ))],
&store,
)
.unwrap(); .unwrap();
match cache.get_deltas_since(1) { match cache.get_deltas_since(1) {
@ -414,7 +428,10 @@ fn delta_new_sorts_announced_descending_and_withdrawn_ascending() {
let w0 = as_v4_route_origin(&delta.withdrawn()[0]); let w0 = as_v4_route_origin(&delta.withdrawn()[0]);
let w1 = as_v4_route_origin(&delta.withdrawn()[1]); let w1 = as_v4_route_origin(&delta.withdrawn()[1]);
assert_eq!(w0.prefix().address.to_ipv4(), Some(Ipv4Addr::new(10, 0, 0, 0))); assert_eq!(
w0.prefix().address.to_ipv4(),
Some(Ipv4Addr::new(10, 0, 0, 0))
);
assert_eq!( assert_eq!(
w1.prefix().address.to_ipv4(), w1.prefix().address.to_ipv4(),
Some(Ipv4Addr::new(203, 0, 113, 0)) Some(Ipv4Addr::new(203, 0, 113, 0))
@ -433,7 +450,8 @@ fn get_deltas_since_returns_up_to_date_when_client_serial_matches_current() {
let result = cache.get_deltas_since(100); let result = cache.get_deltas_since(100);
let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100); let input =
get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 100);
let output = serial_result_detail_to_string(&result); let output = serial_result_detail_to_string(&result);
test_report( test_report(
@ -458,7 +476,9 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old() {
)); ));
let d2 = Arc::new(Delta::new( let d2 = Arc::new(Delta::new(
102, 102,
vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], vec![Payload::RouteOrigin(v4_origin(
198, 51, 100, 0, 24, 24, 64497,
))],
vec![], vec![],
)); ));
@ -506,12 +526,16 @@ fn get_deltas_since_returns_minimal_merged_delta() {
)); ));
let d2 = Arc::new(Delta::new( let d2 = Arc::new(Delta::new(
102, 102,
vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], vec![Payload::RouteOrigin(v4_origin(
198, 51, 100, 0, 24, 24, 64497,
))],
vec![], vec![],
)); ));
let d3 = Arc::new(Delta::new( let d3 = Arc::new(Delta::new(
103, 103,
vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], vec![Payload::RouteOrigin(v4_origin(
203, 0, 113, 0, 24, 24, 64498,
))],
vec![], vec![],
)); ));
@ -584,7 +608,8 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future() {
let result = cache.get_deltas_since(101); let result = cache.get_deltas_since(101);
let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101); let input =
get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 101);
let output = serial_result_detail_to_string(&result); let output = serial_result_detail_to_string(&result);
test_report( test_report(
@ -610,11 +635,7 @@ fn get_deltas_since_supports_incremental_updates_across_serial_wraparound() {
vec![Payload::RouteOrigin(a.clone())], vec![Payload::RouteOrigin(a.clone())],
vec![], vec![],
)); ));
let d_zero = Arc::new(Delta::new( let d_zero = Arc::new(Delta::new(0, vec![Payload::RouteOrigin(b.clone())], vec![]));
0,
vec![Payload::RouteOrigin(b.clone())],
vec![],
));
let mut deltas = VecDeque::new(); let mut deltas = VecDeque::new();
deltas.push_back(d_max); deltas.push_back(d_max);
@ -637,7 +658,11 @@ fn get_deltas_since_supports_incremental_updates_across_serial_wraparound() {
let input = format!( let input = format!(
"{}delta_window:\n{}", "{}delta_window:\n{}",
get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), u32::MAX.wrapping_sub(1)), get_deltas_since_input_to_string(
cache.session_id_for_version(1),
cache.serial(),
u32::MAX.wrapping_sub(1)
),
indent_block(&deltas_window_to_string(&deltas), 2), indent_block(&deltas_window_to_string(&deltas), 2),
); );
let output = serial_result_detail_to_string(&result); let output = serial_result_detail_to_string(&result);
@ -678,12 +703,16 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_
)); ));
let d_zero = Arc::new(Delta::new( let d_zero = Arc::new(Delta::new(
0, 0,
vec![Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497))], vec![Payload::RouteOrigin(v4_origin(
198, 51, 100, 0, 24, 24, 64497,
))],
vec![], vec![],
)); ));
let d_one = Arc::new(Delta::new( let d_one = Arc::new(Delta::new(
1, 1,
vec![Payload::RouteOrigin(v4_origin(203, 0, 113, 0, 24, 24, 64498))], vec![Payload::RouteOrigin(v4_origin(
203, 0, 113, 0, 24, 24, 64498,
))],
vec![], vec![],
)); ));
@ -709,7 +738,11 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_too_old_across_
let input = format!( let input = format!(
"{}delta_window:\n{}", "{}delta_window:\n{}",
get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), client_serial), get_deltas_since_input_to_string(
cache.session_id_for_version(1),
cache.serial(),
client_serial
),
indent_block(&deltas_window_to_string(&deltas), 2), indent_block(&deltas_window_to_string(&deltas), 2),
); );
let output = serial_result_detail_to_string(&result); let output = serial_result_detail_to_string(&result);
@ -737,7 +770,8 @@ fn get_deltas_since_returns_reset_required_when_client_serial_is_in_future_acros
let result = cache.get_deltas_since(0); let result = cache.get_deltas_since(0);
let input = get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 0); let input =
get_deltas_since_input_to_string(cache.session_id_for_version(1), cache.serial(), 0);
let output = serial_result_detail_to_string(&result); let output = serial_result_detail_to_string(&result);
test_report( test_report(
@ -776,10 +810,7 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() {
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let store = RtrStore::open(dir.path()).unwrap(); let store = RtrStore::open(dir.path()).unwrap();
let new_payloads = vec![ let new_payloads = vec![Payload::RouteOrigin(old_b), Payload::RouteOrigin(old_a)];
Payload::RouteOrigin(old_b),
Payload::RouteOrigin(old_a),
];
cache.update(new_payloads.clone(), &store).unwrap(); cache.update(new_payloads.clone(), &store).unwrap();
@ -795,7 +826,10 @@ async fn update_no_change_keeps_serial_and_produces_no_delta() {
let output = format!( let output = format!(
"cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}",
cache.serial(), cache.serial(),
indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2), indent_block(
&snapshot_hashes_and_sorted_view_to_string(&current_snapshot),
2
),
indent_block(&serial_result_detail_to_string(&result), 2), indent_block(&serial_result_detail_to_string(&result), 2),
); );
@ -854,7 +888,10 @@ async fn update_add_only_increments_serial_and_generates_announced_delta() {
let output = format!( let output = format!(
"cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}",
cache.serial(), cache.serial(),
indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2), indent_block(
&snapshot_hashes_and_sorted_view_to_string(&current_snapshot),
2
),
indent_block(&serial_result_detail_to_string(&result), 2), indent_block(&serial_result_detail_to_string(&result), 2),
); );
@ -921,7 +958,10 @@ async fn update_remove_only_increments_serial_and_generates_withdrawn_delta() {
let output = format!( let output = format!(
"cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}",
cache.serial(), cache.serial(),
indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2), indent_block(
&snapshot_hashes_and_sorted_view_to_string(&current_snapshot),
2
),
indent_block(&serial_result_detail_to_string(&result), 2), indent_block(&serial_result_detail_to_string(&result), 2),
); );
@ -997,7 +1037,10 @@ async fn update_add_and_remove_increments_serial_and_generates_both_sides() {
let output = format!( let output = format!(
"cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}", "cache.serial_after_update: {}\ncurrent_snapshot:\n{}get_deltas_since(100):\n{}",
cache.serial(), cache.serial(),
indent_block(&snapshot_hashes_and_sorted_view_to_string(&current_snapshot), 2), indent_block(
&snapshot_hashes_and_sorted_view_to_string(&current_snapshot),
2
),
indent_block(&serial_result_detail_to_string(&result), 2), indent_block(&serial_result_detail_to_string(&result), 2),
); );
@ -1282,10 +1325,7 @@ fn get_deltas_since_merges_multiple_deltas_to_final_minimal_view() {
#[test] #[test]
fn snapshot_from_payloads_unions_aspas_by_customer() { fn snapshot_from_payloads_unions_aspas_by_customer() {
let first = Payload::Aspa(Aspa::new( let first = Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64497u32)]));
Asn::from(64496u32),
vec![Asn::from(64497u32)],
));
let second = Payload::Aspa(Aspa::new( let second = Payload::Aspa(Aspa::new(
Asn::from(64496u32), Asn::from(64496u32),
vec![Asn::from(64498u32), Asn::from(64497u32)], vec![Asn::from(64498u32), Asn::from(64497u32)],
@ -1369,31 +1409,21 @@ fn validate_payloads_for_rtr_rejects_unsorted_snapshot_payloads() {
let high = Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497)); let high = Payload::RouteOrigin(v4_origin(198, 51, 100, 0, 24, 24, 64497));
let err = validate_payloads_for_rtr(&[low, high], true).unwrap_err(); let err = validate_payloads_for_rtr(&[low, high], true).unwrap_err();
assert!(err assert!(err.to_string().contains("RTR payload ordering violation"));
.to_string()
.contains("RTR payload ordering violation"));
} }
#[test] #[test]
fn validate_payload_updates_for_rtr_rejects_unsorted_aspa_updates() { fn validate_payload_updates_for_rtr_rejects_unsorted_aspa_updates() {
let withdraw = ( let withdraw = (
false, false,
Payload::Aspa(Aspa::new( Payload::Aspa(Aspa::new(Asn::from(64497u32), vec![Asn::from(64500u32)])),
Asn::from(64497u32),
vec![Asn::from(64500u32)],
)),
); );
let announce = ( let announce = (
true, true,
Payload::Aspa(Aspa::new( Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64499u32)])),
Asn::from(64496u32),
vec![Asn::from(64499u32)],
)),
); );
let err = validate_payload_updates_for_rtr(&[withdraw, announce]).unwrap_err(); let err = validate_payload_updates_for_rtr(&[withdraw, announce]).unwrap_err();
assert!(err.to_string().contains("withdraw ASPA")); assert!(err.to_string().contains("withdraw ASPA"));
assert!(err.to_string().contains("announce ASPA")); assert!(err.to_string().contains("announce ASPA"));
} }

View File

@ -1,14 +1,10 @@
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use rpki::rtr::ccr::{
ParsedCcrSnapshot,
find_latest_ccr_file,
load_ccr_snapshot_from_file,
snapshot_to_payloads_with_options,
};
use rpki::rtr::loader::{ParsedAspa, ParsedVrp}; use rpki::rtr::loader::{ParsedAspa, ParsedVrp};
use tempfile::tempdir; use tempfile::tempdir;
use rpki::source::ccr::{find_latest_ccr_file, load_ccr_snapshot_from_file,
snapshot_to_payloads_with_options, ParsedCcrSnapshot};
fn fixture_path(name: &str) -> PathBuf { fn fixture_path(name: &str) -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("data").join(name) PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("data").join(name)

311
tests/test_slurm.rs Normal file
View File

@ -0,0 +1,311 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD_NO_PAD;
use rpki::data_model::resources::as_resources::Asn;
use rpki::data_model::resources::ip_resources::{IPAddress, IPAddressPrefix};
use rpki::rtr::payload::{Aspa, Payload, RouteOrigin, RouterKey, Ski};
use rpki::slurm::file::{SlurmFile, SlurmVersion};
fn sample_spki() -> Vec<u8> {
vec![
0x30, 0x13, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
0x05, 0x00, 0x03, 0x02, 0x00, 0x00,
]
}
fn sample_ski() -> [u8; 20] {
[
0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,
0xff, 0x10, 0x20, 0x30, 0x40,
]
}
#[test]
fn parses_rfc8416_v1_slurm() {
let ski_hex = hex::encode(sample_ski());
let router_public_key = STANDARD_NO_PAD.encode(sample_spki());
let json = format!(
r#"{{
"slurmVersion": 1,
"validationOutputFilters": {{
"prefixFilters": [
{{ "prefix": "192.0.2.0/24", "asn": 64496, "comment": "drop roa" }}
],
"bgpsecFilters": [
{{ "asn": 64497, "SKI": "{ski_hex}" }}
]
}},
"locallyAddedAssertions": {{
"prefixAssertions": [
{{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }}
],
"bgpsecAssertions": [
{{ "asn": 64501, "SKI": "{ski_hex}", "routerPublicKey": "{router_public_key}" }}
]
}}
}}"#
);
let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap();
assert_eq!(slurm.version(), SlurmVersion::V1);
assert_eq!(slurm.validation_output_filters().prefix_filters.len(), 1);
assert_eq!(slurm.validation_output_filters().bgpsec_filters.len(), 1);
assert!(slurm.validation_output_filters().aspa_filters.is_empty());
assert_eq!(slurm.locally_added_assertions().prefix_assertions.len(), 1);
assert_eq!(slurm.locally_added_assertions().bgpsec_assertions.len(), 1);
assert!(slurm.locally_added_assertions().aspa_assertions.is_empty());
}
#[test]
fn parses_v2_slurm_with_aspa_extensions() {
let json = r#"{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": [],
"aspaFilters": [
{ "customerAsn": 64496 }
]
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": [],
"aspaAssertions": [
{ "customerAsn": 64510, "providerAsns": [64511, 64512] }
]
}
}"#;
let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap();
assert_eq!(slurm.version(), SlurmVersion::V2);
assert_eq!(slurm.validation_output_filters().aspa_filters.len(), 1);
assert_eq!(slurm.locally_added_assertions().aspa_assertions.len(), 1);
}
#[test]
fn rejects_v1_file_with_aspa_members() {
let json = r#"{
"slurmVersion": 1,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": [],
"aspaFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": []
}
}"#;
let err = SlurmFile::from_slice(json.as_bytes()).unwrap_err();
assert!(err.to_string().contains("unknown field"));
}
#[test]
fn rejects_non_canonical_prefixes_and_unsorted_aspa_providers() {
let non_canonical = r#"{
"slurmVersion": 1,
"validationOutputFilters": {
"prefixFilters": [
{ "prefix": "192.0.2.1/24" }
],
"bgpsecFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": []
}
}"#;
let non_canonical_err = SlurmFile::from_slice(non_canonical.as_bytes()).unwrap_err();
assert!(non_canonical_err.to_string().contains("not canonical"));
let unsorted_aspa = r#"{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": [],
"aspaFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": [],
"aspaAssertions": [
{ "customerAsn": 64500, "providerAsns": [64502, 64501] }
]
}
}"#;
let aspa_err = SlurmFile::from_slice(unsorted_aspa.as_bytes()).unwrap_err();
assert!(aspa_err.to_string().contains("strictly increasing"));
}
#[test]
fn applies_filters_before_assertions_and_excludes_duplicates() {
let ski = Ski::from_bytes(sample_ski());
let spki = sample_spki();
let spki_b64 = STANDARD_NO_PAD.encode(&spki);
let ski_hex = hex::encode(sample_ski());
let json = format!(
r#"{{
"slurmVersion": 2,
"validationOutputFilters": {{
"prefixFilters": [
{{ "prefix": "192.0.2.0/24", "asn": 64496 }}
],
"bgpsecFilters": [
{{ "SKI": "{ski_hex}" }}
],
"aspaFilters": [
{{ "customerAsn": 64496 }}
]
}},
"locallyAddedAssertions": {{
"prefixAssertions": [
{{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }},
{{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }}
],
"bgpsecAssertions": [
{{ "asn": 64501, "SKI": "{ski_hex}", "routerPublicKey": "{spki_b64}" }}
],
"aspaAssertions": [
{{ "customerAsn": 64510, "providerAsns": [64511, 64512] }}
]
}}
}}"#
);
let slurm = SlurmFile::from_slice(json.as_bytes()).unwrap();
let input = vec![
Payload::RouteOrigin(RouteOrigin::new(
IPAddressPrefix::new(IPAddress::from_ipv4("192.0.2.0".parse().unwrap()), 24),
24,
Asn::from(64496u32),
)),
Payload::RouteOrigin(RouteOrigin::new(
IPAddressPrefix::new(IPAddress::from_ipv4("203.0.113.0".parse().unwrap()), 24),
24,
Asn::from(64497u32),
)),
Payload::RouterKey(RouterKey::new(ski, Asn::from(64497u32), spki.clone())),
Payload::Aspa(Aspa::new(Asn::from(64496u32), vec![Asn::from(64498u32)])),
];
let output = slurm.apply(&input);
assert_eq!(output.len(), 4);
assert!(output.iter().any(|payload| matches!(
payload,
Payload::RouteOrigin(route_origin)
if route_origin.prefix().address() == IPAddress::from_ipv4("203.0.113.0".parse().unwrap())
)));
assert!(output.iter().any(|payload| matches!(
payload,
Payload::RouteOrigin(route_origin)
if route_origin.prefix().address() == IPAddress::from_ipv4("198.51.100.0".parse().unwrap())
)));
assert!(output.iter().any(|payload| matches!(
payload,
Payload::RouterKey(router_key)
if router_key.asn() == Asn::from(64501u32)
)));
assert!(output.iter().any(|payload| matches!(
payload,
Payload::Aspa(aspa)
if aspa.customer_asn() == Asn::from(64510u32)
)));
}
#[test]
fn merges_multiple_slurm_files_without_conflict() {
let a = r#"{
"slurmVersion": 1,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [
{ "prefix": "198.51.100.0/24", "asn": 64500, "maxPrefixLength": 24 }
],
"bgpsecAssertions": []
}
}"#;
let b = r#"{
"slurmVersion": 2,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": [],
"aspaFilters": [
{ "customerAsn": 64510 }
]
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": [],
"aspaAssertions": []
}
}"#;
let merged = SlurmFile::merge_named(vec![
(
"a.slurm".to_string(),
SlurmFile::from_slice(a.as_bytes()).unwrap(),
),
(
"b.slurm".to_string(),
SlurmFile::from_slice(b.as_bytes()).unwrap(),
),
])
.unwrap();
assert_eq!(merged.version(), SlurmVersion::V2);
assert_eq!(merged.locally_added_assertions().prefix_assertions.len(), 1);
assert_eq!(merged.validation_output_filters().aspa_filters.len(), 1);
}
#[test]
fn rejects_conflicting_multiple_slurm_files() {
let a = r#"{
"slurmVersion": 1,
"validationOutputFilters": {
"prefixFilters": [],
"bgpsecFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [
{ "prefix": "10.0.0.0/8", "asn": 64500, "maxPrefixLength": 24 }
],
"bgpsecAssertions": []
}
}"#;
let b = r#"{
"slurmVersion": 1,
"validationOutputFilters": {
"prefixFilters": [
{ "prefix": "10.0.0.0/16" }
],
"bgpsecFilters": []
},
"locallyAddedAssertions": {
"prefixAssertions": [],
"bgpsecAssertions": []
}
}"#;
let err = SlurmFile::merge_named(vec![
(
"a.slurm".to_string(),
SlurmFile::from_slice(a.as_bytes()).unwrap(),
),
(
"b.slurm".to_string(),
SlurmFile::from_slice(b.as_bytes()).unwrap(),
),
])
.unwrap_err();
assert!(err.to_string().contains("conflicting SLURM files"));
}