内存优化

This commit is contained in:
xiuting.xu 2026-05-06 15:27:16 +08:00
parent 6e4b59a208
commit 19225edaa7
6 changed files with 211 additions and 121 deletions

54
src/rtr/cache/core.rs vendored
View File

@ -56,7 +56,7 @@ impl SessionIds {
pub struct VersionState { pub struct VersionState {
session_id: u16, session_id: u16,
serial: u32, serial: u32,
snapshot: Snapshot, snapshot: Arc<Snapshot>,
#[serde(skip)] #[serde(skip)]
rtr_payloads: Arc<Vec<Payload>>, rtr_payloads: Arc<Vec<Payload>>,
#[serde(skip)] #[serde(skip)]
@ -69,7 +69,7 @@ impl VersionState {
Self { Self {
session_id, session_id,
serial, serial,
snapshot, snapshot: Arc::new(snapshot),
rtr_payloads, rtr_payloads,
deltas: VecDeque::with_capacity(max_delta as usize), deltas: VecDeque::with_capacity(max_delta as usize),
} }
@ -200,7 +200,7 @@ impl RtrCacheBuilder {
session_id: session_ids.as_array()[idx], session_id: session_ids.as_array()[idx],
serial: serials[idx], serial: serials[idx],
rtr_payloads: Arc::new(snapshot.payloads_for_rtr()), rtr_payloads: Arc::new(snapshot.payloads_for_rtr()),
snapshot, snapshot: Arc::new(snapshot),
deltas: deltas[idx].clone(), deltas: deltas[idx].clone(),
} }
}); });
@ -229,7 +229,7 @@ impl RtrCache {
); );
self.availability = CacheAvailability::NoDataAvailable; self.availability = CacheAvailability::NoDataAvailable;
for version_state in &mut self.versions { for version_state in &mut self.versions {
version_state.snapshot = Snapshot::empty(); version_state.snapshot = Arc::new(Snapshot::empty());
version_state.rtr_payloads = Arc::new(Vec::new()); version_state.rtr_payloads = Arc::new(Vec::new());
version_state.deltas.clear(); version_state.deltas.clear();
} }
@ -246,8 +246,8 @@ impl RtrCache {
let state = &mut self.versions[version]; let state = &mut self.versions[version];
state.session_id = new_session_ids.get(v); state.session_id = new_session_ids.get(v);
state.serial = 1; state.serial = 1;
state.snapshot = project_snapshot_for_version(source_snapshot, v); state.snapshot = Arc::new(project_snapshot_for_version(source_snapshot, v));
state.rtr_payloads = Arc::new(state.snapshot.payloads_for_rtr()); state.rtr_payloads = Arc::new(state.snapshot.as_ref().payloads_for_rtr());
state.deltas.clear(); state.deltas.clear();
} }
self.last_update_end = DualTime::now(); self.last_update_end = DualTime::now();
@ -285,7 +285,7 @@ impl RtrCache {
state.deltas.push_back(delta); state.deltas.push_back(delta);
let mut dropped_serials = Vec::new(); let mut dropped_serials = Vec::new();
if prune_delta_by_snapshot_size { if prune_delta_by_snapshot_size {
let snapshot_wire_size = estimate_snapshot_payload_wire_size(&state.snapshot); let snapshot_wire_size = estimate_snapshot_payload_wire_size(state.snapshot.as_ref());
let mut cumulative_delta_wire_size = let mut cumulative_delta_wire_size =
estimate_delta_window_payload_wire_size(&state.deltas); estimate_delta_window_payload_wire_size(&state.deltas);
while !state.deltas.is_empty() && cumulative_delta_wire_size >= snapshot_wire_size { while !state.deltas.is_empty() && cumulative_delta_wire_size >= snapshot_wire_size {
@ -361,8 +361,8 @@ impl RtrCache {
continue; continue;
} }
state.snapshot = projected; state.snapshot = Arc::new(projected);
state.rtr_payloads = Arc::new(state.snapshot.payloads_for_rtr()); state.rtr_payloads = Arc::new(state.snapshot.as_ref().payloads_for_rtr());
Self::push_delta( Self::push_delta(
state, state,
self.max_delta, self.max_delta,
@ -435,7 +435,7 @@ impl RtrCache {
} }
pub fn snapshot_for_version(&self, version: u8) -> Snapshot { pub fn snapshot_for_version(&self, version: u8) -> Snapshot {
self.versions[version_index(version)].snapshot.clone() self.versions[version_index(version)].snapshot.as_ref().clone()
} }
pub fn rtr_payloads_for_version(&self, version: u8) -> Arc<Vec<Payload>> { pub fn rtr_payloads_for_version(&self, version: u8) -> Arc<Vec<Payload>> {
@ -533,22 +533,22 @@ fn collect_deltas_since(state: &VersionState, client_serial: u32) -> Option<Vec<
} }
fn merge_deltas_minimally(current_serial: u32, deltas: &[Arc<Delta>]) -> Delta { fn merge_deltas_minimally(current_serial: u32, deltas: &[Arc<Delta>]) -> Delta {
let mut states = BTreeMap::<ChangeKey, LogicalState>::new(); let mut states = BTreeMap::<ChangeKey, LogicalStateRef<'_>>::new();
for delta in deltas { for delta in deltas {
for payload in delta.withdrawn() { for payload in delta.withdrawn() {
let key = change_key(payload); let key = change_key(payload);
let state = states.entry(key).or_insert_with(LogicalState::new); let state = states.entry(key).or_insert_with(LogicalStateRef::new);
if state.before.is_none() && state.after.is_none() { if state.before.is_none() && state.after.is_none() {
state.before = Some(payload.clone()); state.before = Some(payload);
} }
state.after = None; state.after = None;
} }
for payload in delta.announced() { for payload in delta.announced() {
let key = change_key(payload); let key = change_key(payload);
let state = states.entry(key).or_insert_with(LogicalState::new); let state = states.entry(key).or_insert_with(LogicalStateRef::new);
state.after = Some(payload.clone()); state.after = Some(payload);
} }
} }
@ -557,17 +557,17 @@ fn merge_deltas_minimally(current_serial: u32, deltas: &[Arc<Delta>]) -> Delta {
for (_key, state) in states { for (_key, state) in states {
match (state.before, state.after) { match (state.before, state.after) {
(None, None) => {} (None, None) => {}
(None, Some(new_payload)) => announced.push(new_payload), (None, Some(new_payload)) => announced.push(new_payload.clone()),
(Some(old_payload), None) => withdrawn.push(old_payload), (Some(old_payload), None) => withdrawn.push(old_payload.clone()),
(Some(old_payload), Some(new_payload)) => { (Some(old_payload), Some(new_payload)) => {
if old_payload != new_payload { if old_payload != new_payload {
if matches!(old_payload, Payload::Aspa(_)) if matches!(old_payload, Payload::Aspa(_))
&& matches!(new_payload, Payload::Aspa(_)) && matches!(new_payload, Payload::Aspa(_))
{ {
announced.push(new_payload); announced.push(new_payload.clone());
} else { } else {
withdrawn.push(old_payload); withdrawn.push(old_payload.clone());
announced.push(new_payload); announced.push(new_payload.clone());
} }
} }
} }
@ -609,7 +609,7 @@ fn project_payload_for_version(payload: &Payload, version: u8) -> Option<Payload
fn estimate_snapshot_payload_wire_size(snapshot: &Snapshot) -> usize { fn estimate_snapshot_payload_wire_size(snapshot: &Snapshot) -> usize {
snapshot snapshot
.payloads_for_rtr() .rtr_payloads_for_rtr_arc()
.iter() .iter()
.map(|payload| estimate_payload_wire_size(payload, true)) .map(|payload| estimate_payload_wire_size(payload, true))
.sum() .sum()
@ -648,13 +648,13 @@ fn estimate_payload_wire_size(payload: &Payload, announce: bool) -> usize {
} }
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Copy, Default)]
struct LogicalState { struct LogicalStateRef<'a> {
before: Option<Payload>, before: Option<&'a Payload>,
after: Option<Payload>, after: Option<&'a Payload>,
} }
impl LogicalState { impl<'a> LogicalStateRef<'a> {
fn new() -> Self { fn new() -> Self {
Self { Self {
before: None, before: None,
@ -671,7 +671,7 @@ pub enum SerialResult {
pub(super) struct AppliedUpdate { pub(super) struct AppliedUpdate {
pub(super) availability: CacheAvailability, pub(super) availability: CacheAvailability,
pub(super) snapshots: [Snapshot; VERSION_COUNT], pub(super) snapshots: [Arc<Snapshot>; VERSION_COUNT],
pub(super) serials: [u32; VERSION_COUNT], pub(super) serials: [u32; VERSION_COUNT],
pub(super) session_ids: [u16; VERSION_COUNT], pub(super) session_ids: [u16; VERSION_COUNT],
pub(super) deltas: [Option<Arc<Delta>>; VERSION_COUNT], pub(super) deltas: [Option<Arc<Delta>>; VERSION_COUNT],

View File

@ -1,5 +1,5 @@
use std::collections::{BTreeMap, BTreeSet}; use std::collections::{BTreeMap, BTreeSet};
use std::sync::OnceLock; use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use chrono::{DateTime, NaiveDateTime, Utc}; use chrono::{DateTime, NaiveDateTime, Utc};
@ -78,6 +78,8 @@ pub struct Snapshot {
router_keys_hash: [u8; 32], router_keys_hash: [u8; 32],
aspas_hash: [u8; 32], aspas_hash: [u8; 32],
snapshot_hash: [u8; 32], snapshot_hash: [u8; 32],
#[serde(skip)]
rtr_payloads_cache: OnceLock<Arc<Vec<Payload>>>,
} }
impl Snapshot { impl Snapshot {
@ -95,8 +97,11 @@ impl Snapshot {
router_keys_hash: [0u8; 32], router_keys_hash: [0u8; 32],
aspas_hash: [0u8; 32], aspas_hash: [0u8; 32],
snapshot_hash: [0u8; 32], snapshot_hash: [0u8; 32],
rtr_payloads_cache: OnceLock::new(),
}; };
snapshot.recompute_hashes(); snapshot.recompute_hashes();
let cached = build_snapshot_payloads_for_rtr(&snapshot);
let _ = snapshot.rtr_payloads_cache.set(Arc::new(cached));
snapshot snapshot
} }
@ -223,9 +228,13 @@ impl Snapshot {
} }
pub fn payloads_for_rtr(&self) -> Vec<Payload> { pub fn payloads_for_rtr(&self) -> Vec<Payload> {
let mut payloads = self.payloads(); self.rtr_payloads_for_rtr_arc().as_ref().clone()
sort_payloads_for_rtr(&mut payloads, true); }
payloads
pub fn rtr_payloads_for_rtr_arc(&self) -> Arc<Vec<Payload>> {
self.rtr_payloads_cache
.get_or_init(|| Arc::new(build_snapshot_payloads_for_rtr(self)))
.clone()
} }
pub fn origins_hash(&self) -> [u8; 32] { pub fn origins_hash(&self) -> [u8; 32] {
@ -354,6 +363,12 @@ fn build_payload_updates_for_rtr(
updates updates
} }
fn build_snapshot_payloads_for_rtr(snapshot: &Snapshot) -> Vec<Payload> {
let mut payloads = snapshot.payloads();
sort_payloads_for_rtr(&mut payloads, true);
payloads
}
fn dedup_payloads(payloads: &mut Vec<Payload>) { fn dedup_payloads(payloads: &mut Vec<Payload>) {
let mut seen = BTreeSet::new(); let mut seen = BTreeSet::new();
payloads.retain(|p| seen.insert(p.clone())); payloads.retain(|p| seen.insert(p.clone()));

View File

@ -1,7 +1,8 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
use crate::rtr::payload::{Payload, Timing}; use crate::rtr::payload::{Payload, Timing};
use crate::rtr::store::RtrStore; use crate::rtr::store::RtrStore;
@ -10,6 +11,12 @@ use super::core::{AppliedUpdate, CacheAvailability, RtrCache, RtrCacheBuilder, S
use super::model::{Delta, Snapshot}; use super::model::{Delta, Snapshot};
const VERSION_COUNT: usize = 3; const VERSION_COUNT: usize = 3;
static STORE_SYNC_WORKER: OnceLock<UnboundedSender<StoreSyncJob>> = OnceLock::new();
struct StoreSyncJob {
store: RtrStore,
update: AppliedUpdate,
}
impl RtrCache { impl RtrCache {
pub fn init( pub fn init(
@ -164,25 +171,46 @@ fn try_restore_from_store(
} }
fn spawn_store_sync(store: &RtrStore, update: AppliedUpdate) { fn spawn_store_sync(store: &RtrStore, update: AppliedUpdate) {
tokio::spawn({ let tx = STORE_SYNC_WORKER.get_or_init(|| {
let store = store.clone(); let (tx, mut rx) = unbounded_channel::<StoreSyncJob>();
async move { tokio::spawn(async move {
while let Some(mut job) = rx.recv().await {
while let Ok(next) = rx.try_recv() {
job = next;
}
persist_update_job(job);
}
});
tx
});
if let Err(err) = tx.send(StoreSyncJob {
store: store.clone(),
update,
}) {
tracing::warn!(
"store sync worker channel closed, falling back to inline persist: {:?}",
err
);
persist_update_job(err.0);
}
}
fn persist_update_job(job: StoreSyncJob) {
let delta_refs: [Option<&Delta>; 3] = let delta_refs: [Option<&Delta>; 3] =
std::array::from_fn(|idx| update.deltas[idx].as_deref()); std::array::from_fn(|idx| job.update.deltas[idx].as_deref());
if let Err(e) = store.save_cache_state_versioned( if let Err(e) = job.store.save_cache_state_versioned(
update.availability, job.update.availability,
&update.snapshots, &job.update.snapshots,
&update.session_ids, &job.update.session_ids,
&update.serials, &job.update.serials,
&delta_refs, &delta_refs,
&update.delta_windows, &job.update.delta_windows,
&update.clear_delta_windows, &job.update.clear_delta_windows,
) { ) {
tracing::error!("persist cache state failed: {:?}", e); tracing::error!("persist cache state failed: {:?}", e);
} }
} }
});
}
fn project_snapshot_for_version(snapshot: &Snapshot, version: u8) -> Snapshot { fn project_snapshot_for_version(snapshot: &Snapshot, version: u8) -> Snapshot {
let mut payloads = Vec::new(); let mut payloads = Vec::new();

View File

@ -5,7 +5,7 @@ use std::time::{Duration, Instant};
use anyhow::{Result, anyhow, bail}; use anyhow::{Result, anyhow, bail};
use tokio::io; use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::sync::{broadcast, watch}; use tokio::sync::{broadcast, watch};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
@ -735,9 +735,16 @@ where
client_serial, client_serial,
offending_pdu.len() offending_pdu.len()
); );
let (data_available, current_session, current_serial, timing, serial_result) = {
let cache = self.cache.load_full(); let cache = self.cache.load_full();
let data_available = cache.is_data_available(); (
let current_session = cache.session_id_for_version(version); cache.is_data_available(),
cache.session_id_for_version(version),
cache.serial_for_version(version),
cache.timing(),
cache.get_deltas_since_for_version(version, client_serial),
)
};
if !data_available { if !data_available {
self.send_no_data_available(offending_pdu, "cache data is not currently available") self.send_no_data_available(offending_pdu, "cache data is not currently available")
@ -765,11 +772,6 @@ where
); );
} }
let serial_result = cache.get_deltas_since_for_version(version, client_serial);
let current_session = cache.session_id_for_version(version);
let current_serial = cache.serial_for_version(version);
let timing = cache.timing();
match serial_result { match serial_result {
SerialResult::ResetRequired => { SerialResult::ResetRequired => {
self.write_cache_reset().await?; self.write_cache_reset().await?;
@ -784,11 +786,10 @@ where
} }
SerialResult::UpToDate => { SerialResult::UpToDate => {
self.write_cache_response(current_session).await?;
self.write_end_of_data(current_session, current_serial, timing) self.write_end_of_data(current_session, current_serial, timing)
.await?; .await?;
debug!( debug!(
"RTR session replied CacheResponse+EndOfData (up-to-date) to Serial Query: client_session_id={}, client_serial={}, response_session_id={}, response_serial={}, state={}, negotiated_version={:?}", "RTR session replied EndOfData (up-to-date) to Serial Query: client_session_id={}, client_serial={}, response_session_id={}, response_serial={}, state={}, negotiated_version={:?}",
client_session, client_session,
client_serial, client_serial,
current_session, current_session,
@ -976,6 +977,7 @@ where
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12
validate_payloads_for_rtr(payloads, announce).map_err(|err| anyhow!(err.to_string()))?; validate_payloads_for_rtr(payloads, announce).map_err(|err| anyhow!(err.to_string()))?;
let version = self.version()?;
let (route_origins, router_keys, aspas) = count_payloads(payloads); let (route_origins, router_keys, aspas) = count_payloads(payloads);
debug!( debug!(
"RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}", "RTR session sending snapshot payloads: announce={}, total={}, route_origins={}, router_keys={}, aspas={}",
@ -985,9 +987,11 @@ where
router_keys, router_keys,
aspas aspas
); );
let mut writer = BufWriter::new(&mut self.stream);
for payload in payloads { for payload in payloads {
self.send_payload(payload, announce).await?; Self::send_payload_to(&mut writer, payload, announce, version).await?;
} }
writer.flush().await?;
Ok(()) Ok(())
} }
@ -1001,6 +1005,7 @@ where
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-11.4
// https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12 // https://datatracker.ietf.org/doc/html/draft-ietf-sidrops-8210bis-25#section-12
validate_payload_updates_for_rtr(&updates).map_err(|err| anyhow!(err.to_string()))?; validate_payload_updates_for_rtr(&updates).map_err(|err| anyhow!(err.to_string()))?;
let version = self.version()?;
let (announced, withdrawn, route_origins, router_keys, aspas) = let (announced, withdrawn, route_origins, router_keys, aspas) =
count_payload_updates(&updates); count_payload_updates(&updates);
debug!( debug!(
@ -1012,35 +1017,50 @@ where
router_keys, router_keys,
aspas aspas
); );
let mut writer = BufWriter::new(&mut self.stream);
for (announce, payload) in updates { for (announce, payload) in updates {
self.send_payload(payload, *announce).await?; Self::send_payload_to(&mut writer, payload, *announce, version).await?;
} }
writer.flush().await?;
Ok(()) Ok(())
} }
async fn send_payload(&mut self, payload: &Payload, announce: bool) -> Result<()> { async fn send_payload_to<W>(
let version = self.version()?; writer: &mut W,
payload: &Payload,
announce: bool,
version: u8,
) -> Result<()>
where
W: AsyncWrite + Unpin,
{
match payload { match payload {
Payload::RouteOrigin(origin) => { Payload::RouteOrigin(origin) => {
self.send_route_origin(origin, announce).await?; Self::send_route_origin_to(writer, origin, announce, version).await?;
} }
Payload::RouterKey(key) => { Payload::RouterKey(key) => {
if version >= 1 { if version >= 1 {
self.send_router_key(key, announce).await?; Self::send_router_key_to(writer, key, announce, version).await?;
} }
} }
Payload::Aspa(aspa) => { Payload::Aspa(aspa) => {
if version >= 2 { if version >= 2 {
self.send_aspa(aspa, announce).await?; Self::send_aspa_to(writer, aspa, announce, version).await?;
} }
} }
} }
Ok(()) Ok(())
} }
async fn send_route_origin(&mut self, origin: &RouteOrigin, announce: bool) -> Result<()> { async fn send_route_origin_to<W>(
let version = self.version()?; writer: &mut W,
origin: &RouteOrigin,
announce: bool,
version: u8,
) -> Result<()>
where
W: AsyncWrite + Unpin,
{
let flags = Flags::new(if announce { let flags = Flags::new(if announce {
ANNOUNCE_FLAG ANNOUNCE_FLAG
} else { } else {
@ -1054,12 +1074,12 @@ where
match prefix.address { match prefix.address {
IPAddress::V4(v4) => { IPAddress::V4(v4) => {
IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn()) IPv4Prefix::new(version, flags, prefix_len, max_len, v4, origin.asn())
.write(&mut self.stream) .write(writer)
.await?; .await?;
} }
IPAddress::V6(v6) => { IPAddress::V6(v6) => {
IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn()) IPv6Prefix::new(version, flags, prefix_len, max_len, v6, origin.asn())
.write(&mut self.stream) .write(writer)
.await?; .await?;
} }
} }
@ -1067,8 +1087,15 @@ where
Ok(()) Ok(())
} }
async fn send_router_key(&mut self, key: &RouterKey, announce: bool) -> Result<()> { async fn send_router_key_to<W>(
let version = self.version()?; writer: &mut W,
key: &RouterKey,
announce: bool,
version: u8,
) -> Result<()>
where
W: AsyncWrite + Unpin,
{
key.validate()?; key.validate()?;
let flags = Flags::new(if announce { let flags = Flags::new(if announce {
@ -1085,13 +1112,19 @@ where
Arc::<[u8]>::from(key.spki().to_vec()), Arc::<[u8]>::from(key.spki().to_vec()),
); );
pdu.write(&mut self.stream).await?; pdu.write(writer).await?;
Ok(()) Ok(())
} }
async fn send_aspa(&mut self, aspa: &Aspa, announce: bool) -> Result<()> { async fn send_aspa_to<W>(
let version = self.version()?; writer: &mut W,
aspa: &Aspa,
announce: bool,
version: u8,
) -> Result<()>
where
W: AsyncWrite + Unpin,
{
if announce { if announce {
aspa.validate_announcement()?; aspa.validate_announcement()?;
} }
@ -1113,7 +1146,7 @@ where
let pdu = AspaPdu::new(version, flags, aspa.customer_asn().into_u32(), providers); let pdu = AspaPdu::new(version, flags, aspa.customer_asn().into_u32(), providers);
pdu.write(&mut self.stream).await?; pdu.write(writer).await?;
Ok(()) Ok(())
} }

View File

@ -1,6 +1,7 @@
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use rocksdb::{ColumnFamilyDescriptor, DB, IteratorMode, Options, WriteBatch}; use rocksdb::{ColumnFamilyDescriptor, DB, IteratorMode, Options, WriteBatch};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::borrow::Borrow;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tracing::{info, warn}; use tracing::{info, warn};
@ -28,6 +29,14 @@ fn delta_key_v2(version: u8, serial: u32) -> [u8; 6] {
key key
} }
fn delta_prefix_start(version: u8) -> [u8; 6] {
delta_key_v2(version, 0)
}
fn delta_prefix_end_exclusive(version: u8) -> [u8; 6] {
delta_key_v2(version.saturating_add(1), 0)
}
fn delta_key_v2_serial(key: &[u8]) -> Option<(u8, u32)> { fn delta_key_v2_serial(key: &[u8]) -> Option<(u8, u32)> {
if key.len() != 6 || key[0] != DELTA_KEY_V2_PREFIX { if key.len() != 6 || key[0] != DELTA_KEY_V2_PREFIX {
return None; return None;
@ -201,16 +210,19 @@ impl RtrStore {
/// - Direct callers should be limited to DB contract tests. /// - Direct callers should be limited to DB contract tests.
/// - Do not introduce ad-hoc write paths outside cache/store, otherwise /// - Do not introduce ad-hoc write paths outside cache/store, otherwise
/// session_id/serial/snapshot/delta_window consistency can be broken. /// session_id/serial/snapshot/delta_window consistency can be broken.
pub fn save_cache_state_versioned( pub fn save_cache_state_versioned<S>(
&self, &self,
availability: CacheAvailability, availability: CacheAvailability,
snapshots: &[Snapshot; 3], snapshots: &[S; 3],
session_ids: &[u16; 3], session_ids: &[u16; 3],
serials: &[u32; 3], serials: &[u32; 3],
deltas: &[Option<&Delta>; 3], deltas: &[Option<&Delta>; 3],
delta_windows: &[Option<(u32, u32)>; 3], delta_windows: &[Option<(u32, u32)>; 3],
clear_delta_windows: &[bool; 3], clear_delta_windows: &[bool; 3],
) -> Result<()> { ) -> Result<()>
where
S: Borrow<Snapshot>,
{
let snapshot_cf = self let snapshot_cf = self
.db .db
.cf_handle(CF_SNAPSHOT) .cf_handle(CF_SNAPSHOT)
@ -236,7 +248,7 @@ impl RtrStore {
batch.put_cf( batch.put_cf(
snapshot_cf, snapshot_cf,
snapshot_key(version), snapshot_key(version),
serde_json::to_vec(&snapshots[idx])?, serde_json::to_vec(snapshots[idx].borrow())?,
); );
batch.put_cf( batch.put_cf(
meta_cf, meta_cf,
@ -260,9 +272,11 @@ impl RtrStore {
if clear_delta_windows[idx] { if clear_delta_windows[idx] {
batch.delete_cf(meta_cf, meta_key(META_DELTA_MIN_PREFIX, version)); batch.delete_cf(meta_cf, meta_key(META_DELTA_MIN_PREFIX, version));
batch.delete_cf(meta_cf, meta_key(META_DELTA_MAX_PREFIX, version)); batch.delete_cf(meta_cf, meta_key(META_DELTA_MAX_PREFIX, version));
for key in self.list_delta_keys_for_version(version)? { batch.delete_range_cf(
batch.delete_cf(delta_cf, key); delta_cf,
} delta_prefix_start(version),
delta_prefix_end_exclusive(version),
);
} else if let Some((min_serial, max_serial)) = delta_windows[idx] { } else if let Some((min_serial, max_serial)) = delta_windows[idx] {
batch.put_cf( batch.put_cf(
meta_cf, meta_cf,
@ -274,6 +288,22 @@ impl RtrStore {
meta_key(META_DELTA_MAX_PREFIX, version), meta_key(META_DELTA_MAX_PREFIX, version),
serde_json::to_vec(&max_serial)?, serde_json::to_vec(&max_serial)?,
); );
if min_serial <= max_serial {
if min_serial > 0 {
batch.delete_range_cf(
delta_cf,
delta_prefix_start(version),
delta_key_v2(version, min_serial),
);
}
if max_serial < u32::MAX {
batch.delete_range_cf(
delta_cf,
delta_key_v2(version, max_serial.wrapping_add(1)),
delta_prefix_end_exclusive(version),
);
}
} else {
for key in self for key in self
.list_delta_keys_outside_window_for_version(version, min_serial, max_serial)? .list_delta_keys_outside_window_for_version(version, min_serial, max_serial)?
{ {
@ -281,6 +311,7 @@ impl RtrStore {
} }
} }
} }
}
self.db.write(batch)?; self.db.write(batch)?;
Ok(()) Ok(())

View File

@ -5,8 +5,9 @@ use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use arc_swap::ArcSwap;
use rustls::{ClientConfig, RootCertStore}; use rustls::{ClientConfig, RootCertStore};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use serde_json::json; use serde_json::json;
@ -37,7 +38,7 @@ use rpki::rtr::server::tls::load_rustls_server_config_with_options;
use rpki::rtr::session::RtrSession; use rpki::rtr::session::RtrSession;
fn shared_cache(cache: rpki::rtr::cache::RtrCache) -> SharedRtrCache { fn shared_cache(cache: rpki::rtr::cache::RtrCache) -> SharedRtrCache {
Arc::new(RwLock::new(cache)) Arc::new(ArcSwap::from_pointee(cache))
} }
async fn start_session_server( async fn start_session_server(
@ -450,7 +451,7 @@ async fn restart_restores_versioned_state_and_serves_queries() {
ResetQuery::new(0).write(&mut client).await.unwrap(); ResetQuery::new(0).write(&mut client).await.unwrap();
let response = CacheResponse::read(&mut client).await.unwrap(); let response = CacheResponse::read(&mut client).await.unwrap();
assert_eq!(response.version(), 0); assert_eq!(response.version(), 0);
let expected_sid_v0 = shared.read().unwrap().session_id_for_version(0); let expected_sid_v0 = shared.load_full().session_id_for_version(0);
assert_eq!(response.session_id(), expected_sid_v0); assert_eq!(response.session_id(), expected_sid_v0);
let _v4 = IPv4Prefix::read(&mut client).await.unwrap(); let _v4 = IPv4Prefix::read(&mut client).await.unwrap();
let eod_v0 = rpki::rtr::pdu::EndOfDataV0::read(&mut client).await.unwrap(); let eod_v0 = rpki::rtr::pdu::EndOfDataV0::read(&mut client).await.unwrap();
@ -462,7 +463,7 @@ async fn restart_restores_versioned_state_and_serves_queries() {
ResetQuery::new(1).write(&mut client).await.unwrap(); ResetQuery::new(1).write(&mut client).await.unwrap();
let response = CacheResponse::read(&mut client).await.unwrap(); let response = CacheResponse::read(&mut client).await.unwrap();
assert_eq!(response.version(), 1); assert_eq!(response.version(), 1);
let expected_sid_v1 = shared.read().unwrap().session_id_for_version(1); let expected_sid_v1 = shared.load_full().session_id_for_version(1);
assert_eq!(response.session_id(), expected_sid_v1); assert_eq!(response.session_id(), expected_sid_v1);
let _v4 = IPv4Prefix::read(&mut client).await.unwrap(); let _v4 = IPv4Prefix::read(&mut client).await.unwrap();
let _rk = RouterKeyPdu::read(&mut client).await.unwrap(); let _rk = RouterKeyPdu::read(&mut client).await.unwrap();
@ -472,7 +473,7 @@ async fn restart_restores_versioned_state_and_serves_queries() {
let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(shared.clone()).await; let (addr, _notify_tx, shutdown_tx, server_handle) = start_session_server(shared.clone()).await;
let mut client = TcpStream::connect(addr).await.unwrap(); let mut client = TcpStream::connect(addr).await.unwrap();
let sid_v2 = shared.read().unwrap().session_id_for_version(2); let sid_v2 = shared.load_full().session_id_for_version(2);
SerialQuery::new(2, sid_v2, 1).write(&mut client).await.unwrap(); SerialQuery::new(2, sid_v2, 1).write(&mut client).await.unwrap();
let response = CacheResponse::read(&mut client).await.unwrap(); let response = CacheResponse::read(&mut client).await.unwrap();
assert_eq!(response.version(), 2); assert_eq!(response.version(), 2);
@ -1654,7 +1655,7 @@ async fn version_one_sends_router_key_but_not_aspa() {
} }
#[tokio::test] #[tokio::test]
async fn established_session_idle_timeout_returns_transport_failed() { async fn established_session_idle_does_not_trigger_transport_failed() {
let cache = RtrCacheBuilder::new() let cache = RtrCacheBuilder::new()
.session_ids(SessionIds::from_array([42, 42, 42])) .session_ids(SessionIds::from_array([42, 42, 42]))
.serials(serials_all(100)) .serials(serials_all(100))
@ -1675,38 +1676,20 @@ async fn established_session_idle_timeout_returns_transport_failed() {
let eod = EndOfDataV1::read(&mut client).await.unwrap(); let eod = EndOfDataV1::read(&mut client).await.unwrap();
dump.push_value(eod.pdu(), dump_eod_v1(&eod)); dump.push_value(eod.pdu(), dump_eod_v1(&eod));
let report = timeout(Duration::from_secs(1), ErrorReport::read(&mut client)) let idle_res = timeout(Duration::from_millis(300), Header::read(&mut client)).await;
.await assert!(
.expect("timed out waiting for transport failure ErrorReport") idle_res.is_err(),
.unwrap(); "established session should stay idle without sending TransportFailed"
dump.push_value(
ErrorReport::PDU,
json!({
"version": report.version(),
"pdu": ErrorReport::PDU,
"pdu_name": "Error Report",
"error_code": report.error_code().map(|code| code.as_u16()).unwrap_or_else(|code| code),
"erroneous_pdu_len": report.erroneous_pdu().len(),
"erroneous_pdu_hex": common::test_helper::bytes_to_hex(report.erroneous_pdu()),
"text": String::from_utf8_lossy(report.text()),
}),
); );
assert_eq!(report.version(), 1);
assert_eq!(report.error_code(), Ok(ErrorCode::TransportFailed));
assert!(report.erroneous_pdu().is_empty());
assert!(std::str::from_utf8(report.text()).unwrap().contains("transport stalled"));
let read_res = Header::read(&mut client).await;
assert!(read_res.is_err());
dump.push_value( dump.push_value(
0, 0,
json!({ json!({
"event": "connection_closed_after_transport_timeout", "event": "no_transport_failed_on_established_idle",
"result": "header_read_failed_as_expected" "result": "no_pdu_received_within_timeout",
"timeout_ms": 300
}), }),
); );
dump.print_pretty("established_session_idle_timeout_returns_transport_failed"); dump.print_pretty("established_session_idle_does_not_trigger_transport_failed");
shutdown_server(client, shutdown_tx, server_handle).await; shutdown_server(client, shutdown_tx, server_handle).await;
} }