rpki/src/rtr/cache/core.rs

690 lines
21 KiB
Rust

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeMap, VecDeque};
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::rtr::payload::{Payload, Timing};
use super::model::{Delta, DualTime, Snapshot};
use super::ordering::{ChangeKey, change_key};
/// Half of the 32-bit serial-number space (2^31). `serial_cmp` treats a
/// wrapped difference of exactly this value as incomparable, and smaller /
/// larger differences as "greater" / "less" (RFC 1982 style arithmetic).
const SERIAL_HALF_RANGE: u32 = 1 << 31;
/// Number of RTR protocol versions kept side by side (versions 0, 1 and 2,
/// see `version_index`).
const VERSION_COUNT: usize = 3;
/// Whether the cache currently holds data it can serve.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, PartialEq)]
pub enum CacheAvailability {
    /// Data is loaded and is being served.
    Ready,
    /// No usable data; entered when an update delivers an empty snapshot.
    NoDataAvailable,
}
/// One RTR session ID per supported protocol version, indexed like
/// `version_index`.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct SessionIds {
    ids: [u16; VERSION_COUNT],
}
impl SessionIds {
    /// Wraps an explicit set of per-version session IDs.
    pub fn from_array(ids: [u16; VERSION_COUNT]) -> Self {
        Self { ids }
    }

    /// Draws random session IDs, retrying until all versions have distinct IDs.
    pub fn random_distinct() -> Self {
        let mut ids = [0u16; VERSION_COUNT];
        let mut filled = 0;
        while filled < VERSION_COUNT {
            let candidate: u16 = rand::random();
            // Reject values already handed to a lower version index.
            if !ids[..filled].contains(&candidate) {
                ids[filled] = candidate;
                filled += 1;
            }
        }
        Self { ids }
    }

    /// Session ID for `version`; panics for unsupported versions.
    pub fn get(&self, version: u8) -> u16 {
        let slot = version_index(version);
        self.ids[slot]
    }

    /// Copy of the raw per-version ID array.
    pub fn as_array(&self) -> [u16; VERSION_COUNT] {
        let ids = self.ids;
        ids
    }
}
/// Everything the cache serves for a single RTR protocol version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionState {
    /// Session ID announced to clients of this version.
    session_id: u16,
    /// Current serial number of `snapshot`.
    serial: u32,
    /// Full data set for this version.
    snapshot: Arc<Snapshot>,
    /// Cached RTR payload list derived from `snapshot`. Skipped by serde, so it
    /// holds its `Default` value after deserialization
    /// (NOTE(review): presumably rebuilt by the loader — confirm).
    #[serde(skip)]
    rtr_payloads: Arc<Vec<Payload>>,
    /// Retained deltas, oldest first. Also skipped by serde.
    #[serde(skip)]
    deltas: VecDeque<Arc<Delta>>,
}
impl VersionState {
    /// Creates a state for `snapshot` at `serial`, with an empty delta queue
    /// pre-sized for `max_delta` entries.
    fn new(session_id: u16, serial: u32, snapshot: Snapshot, max_delta: u8) -> Self {
        let rtr_payloads = snapshot.rtr_payloads_for_rtr_arc();
        let deltas = VecDeque::with_capacity(usize::from(max_delta));
        let snapshot = Arc::new(snapshot);
        Self {
            session_id,
            serial,
            snapshot,
            rtr_payloads,
            deltas,
        }
    }
}
/// RTR cache holding one independent view (session ID, serial, snapshot and
/// delta history) per supported protocol version.
#[derive(Debug, Clone)]
pub struct RtrCache {
    /// Whether data can currently be served.
    availability: CacheAvailability,
    /// Per-version state, indexed by `version_index`.
    versions: [VersionState; VERSION_COUNT],
    /// Maximum number of deltas retained per version (`push_delta` keeps at
    /// least one even when this is 0).
    max_delta: u8,
    /// When set, old deltas are also dropped once the retained deltas are
    /// estimated to be at least as large on the wire as the snapshot.
    prune_delta_by_snapshot_size: bool,
    /// Timing parameters exposed through `timing()`.
    timing: Timing,
    /// Time the most recent update started.
    last_update_begin: DualTime,
    /// Time the most recent update finished.
    last_update_end: DualTime,
    /// Creation time of the cache.
    created_at: DualTime,
}
impl Default for RtrCache {
    /// An empty cache in the `Ready` state.
    ///
    /// Every version starts at serial 0 with an empty snapshot and a random,
    /// per-version-distinct session ID. It retains up to 100 deltas, performs
    /// no size-based pruning and uses default timing.
    fn default() -> Self {
        const DEFAULT_MAX_DELTA: u8 = 100;
        let created_at = DualTime::now();
        let ids = SessionIds::random_distinct().as_array();
        Self {
            availability: CacheAvailability::Ready,
            versions: std::array::from_fn(|slot| {
                VersionState::new(ids[slot], 0, Snapshot::empty(), DEFAULT_MAX_DELTA)
            }),
            max_delta: DEFAULT_MAX_DELTA,
            prune_delta_by_snapshot_size: false,
            timing: Timing::default(),
            last_update_begin: created_at.clone(),
            last_update_end: created_at.clone(),
            created_at,
        }
    }
}
/// Step-by-step configuration of an [`RtrCache`].
///
/// Unset options fall back to the same values `RtrCache::default()` uses:
/// - random distinct session IDs;
/// - serial 0 and an empty snapshot and delta queue per version;
/// - `max_delta` 100, no size-based pruning, default timing;
/// - `Ready` availability, with `created_at` set to the build time.
#[derive(Default)]
pub struct RtrCacheBuilder {
    availability: Option<CacheAvailability>,
    session_ids: Option<SessionIds>,
    max_delta: Option<u8>,
    prune_delta_by_snapshot_size: Option<bool>,
    timing: Option<Timing>,
    serials: Option<[u32; VERSION_COUNT]>,
    snapshots: Option<[Snapshot; VERSION_COUNT]>,
    deltas: Option<[VecDeque<Arc<Delta>>; VERSION_COUNT]>,
    created_at: Option<DualTime>,
}
impl RtrCacheBuilder {
    /// Creates a builder with every option unset (same as `default()`).
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the per-version session IDs.
    pub fn session_ids(mut self, v: SessionIds) -> Self {
        self.session_ids = Some(v);
        self
    }
    /// Sets the initial availability.
    pub fn availability(mut self, v: CacheAvailability) -> Self {
        self.availability = Some(v);
        self
    }
    /// Sets the maximum number of retained deltas per version.
    pub fn max_delta(mut self, v: u8) -> Self {
        self.max_delta = Some(v);
        self
    }
    /// Enables or disables pruning deltas by estimated snapshot wire size.
    pub fn prune_delta_by_snapshot_size(mut self, v: bool) -> Self {
        self.prune_delta_by_snapshot_size = Some(v);
        self
    }
    /// Sets the timing parameters.
    pub fn timing(mut self, v: Timing) -> Self {
        self.timing = Some(v);
        self
    }
    /// Sets the per-version serial numbers.
    pub fn serials(mut self, v: [u32; VERSION_COUNT]) -> Self {
        self.serials = Some(v);
        self
    }
    /// Sets the per-version snapshots.
    pub fn snapshots(mut self, v: [Snapshot; VERSION_COUNT]) -> Self {
        self.snapshots = Some(v);
        self
    }
    /// Sets the per-version delta queues (oldest delta first).
    pub fn deltas_by_version(mut self, v: [VecDeque<Arc<Delta>>; VERSION_COUNT]) -> Self {
        self.deltas = Some(v);
        self
    }
    /// Sets the creation timestamp.
    pub fn created_at(mut self, v: DualTime) -> Self {
        self.created_at = Some(v);
        self
    }
    /// Assembles the cache, filling unset options with defaults.
    ///
    /// Provided snapshots and delta queues are moved into the cache rather
    /// than cloned.
    pub fn build(self) -> RtrCache {
        let now = DualTime::now();
        let max_delta = self.max_delta.unwrap_or(100);
        let prune_delta_by_snapshot_size = self.prune_delta_by_snapshot_size.unwrap_or(false);
        let timing = self.timing.unwrap_or_default();
        let session_ids = self
            .session_ids
            .unwrap_or_else(SessionIds::random_distinct)
            .as_array();
        let serials = self.serials.unwrap_or([0; VERSION_COUNT]);
        let snapshots = self
            .snapshots
            .unwrap_or_else(|| std::array::from_fn(|_| Snapshot::empty()));
        let deltas = self.deltas.unwrap_or_else(|| {
            std::array::from_fn(|_| VecDeque::with_capacity(max_delta as usize))
        });
        // Consume the owned arrays instead of cloning each element.
        // `array::from_fn` calls the closure for indices 0, 1, 2 in order, so
        // the iterators stay aligned with `idx`.
        let mut snapshots = IntoIterator::into_iter(snapshots);
        let mut deltas = IntoIterator::into_iter(deltas);
        let versions = std::array::from_fn(|idx| {
            let snapshot = snapshots
                .next()
                .expect("snapshots array has one entry per protocol version");
            VersionState {
                session_id: session_ids[idx],
                serial: serials[idx],
                rtr_payloads: snapshot.rtr_payloads_for_rtr_arc(),
                snapshot: Arc::new(snapshot),
                deltas: deltas
                    .next()
                    .expect("deltas array has one entry per protocol version"),
            }
        });
        let created_at = self.created_at.unwrap_or_else(|| now.clone());
        let availability = self.availability.unwrap_or(CacheAvailability::Ready);
        RtrCache {
            availability,
            versions,
            max_delta,
            prune_delta_by_snapshot_size,
            timing,
            last_update_begin: now.clone(),
            last_update_end: now,
            created_at,
        }
    }
}
impl RtrCache {
    /// Moves the cache into `NoDataAvailable` and drops all served data.
    ///
    /// For every version:
    /// - the snapshot is replaced by an empty one;
    /// - the cached RTR payload list is rebuilt from that empty snapshot;
    /// - the delta history is cleared.
    ///
    /// Serials and session IDs are left untouched here; they are replaced by
    /// `reinitialize_from_snapshot` once usable data arrives again.
    fn set_unavailable(&mut self) {
        warn!(
            "RTR cache entering NoDataAvailable: serials={:?}",
            self.serials()
        );
        self.availability = CacheAvailability::NoDataAvailable;
        for version_state in &mut self.versions {
            version_state.snapshot = Arc::new(Snapshot::empty());
            version_state.rtr_payloads = version_state.snapshot.rtr_payloads_for_rtr_arc();
            version_state.deltas.clear();
        }
    }

    /// Returns the cache to `Ready` from `source_snapshot` after a period
    /// without data.
    ///
    /// Every version gets a new session ID (different from its previous one),
    /// restarts at serial 1 with the projected snapshot, and loses its delta
    /// history. The returned update asks consumers to clear their delta
    /// windows.
    fn reinitialize_from_snapshot(&mut self, source_snapshot: &Snapshot) -> AppliedUpdate {
        let old_serials = self.serials();
        let old_session_ids = self.session_ids();
        // Serials restart at 1, so reusing an old session ID could let a
        // client's stale serial look valid; insist on fresh IDs.
        let new_session_ids = Self::fresh_session_ids(&old_session_ids);
        self.availability = CacheAvailability::Ready;
        for version in 0..VERSION_COUNT {
            let v = version as u8;
            let state = &mut self.versions[version];
            state.session_id = new_session_ids.get(v);
            state.serial = 1;
            state.snapshot = Arc::new(project_snapshot_for_version(source_snapshot, v));
            state.rtr_payloads = state.snapshot.rtr_payloads_for_rtr_arc();
            state.deltas.clear();
        }
        self.last_update_end = DualTime::now();
        info!(
            "RTR cache reinitialized from usable snapshot: old_serials={:?}, new_serials={:?}, old_session_ids={:?}, new_session_ids={:?}",
            old_serials,
            self.serials(),
            old_session_ids,
            new_session_ids
        );
        self.applied_update_with_clear()
    }

    /// Draws random, per-version-distinct session IDs that also differ from
    /// `previous` in every version slot.
    fn fresh_session_ids(previous: &SessionIds) -> SessionIds {
        let previous = previous.as_array();
        loop {
            let candidate = SessionIds::random_distinct();
            let all_changed = candidate
                .as_array()
                .iter()
                .zip(previous.iter())
                .all(|(new_id, old_id)| new_id != old_id);
            if all_changed {
                return candidate;
            }
        }
    }

    /// Advances `state.serial` by one (wrapping at `u32::MAX`) and returns it.
    fn next_serial(state: &mut VersionState) -> u32 {
        let old = state.serial;
        state.serial = state.serial.wrapping_add(1);
        debug!(
            "RTR cache advanced serial for version state: old_serial={}, new_serial={}",
            old, state.serial
        );
        state.serial
    }

    /// Appends `delta` to the version's history and trims that history.
    ///
    /// The queue is first capped at `max_delta` entries (at least one). When
    /// `prune_delta_by_snapshot_size` is set, the oldest deltas are then
    /// dropped for as long as the estimated wire size of the retained deltas
    /// is at least the estimated wire size of the current snapshot. This can
    /// drop every delta, including the one just pushed.
    fn push_delta(
        state: &mut VersionState,
        max_delta: u8,
        prune_delta_by_snapshot_size: bool,
        delta: Arc<Delta>,
    ) {
        let max_keep = usize::from(max_delta.max(1));
        while state.deltas.len() >= max_keep {
            state.deltas.pop_front();
        }
        state.deltas.push_back(delta);
        if !prune_delta_by_snapshot_size {
            return;
        }
        let snapshot_wire_size = estimate_snapshot_payload_wire_size(state.snapshot.as_ref());
        let mut cumulative_delta_wire_size =
            estimate_delta_window_payload_wire_size(&state.deltas);
        let mut dropped_serials = Vec::new();
        // Keep a running total: subtract each dropped delta instead of
        // re-estimating the whole window after every pop, which was quadratic
        // in the window length.
        while cumulative_delta_wire_size >= snapshot_wire_size {
            match state.deltas.pop_front() {
                Some(oldest) => {
                    dropped_serials.push(oldest.serial());
                    cumulative_delta_wire_size = cumulative_delta_wire_size
                        .saturating_sub(estimate_delta_wire_size(&oldest));
                }
                None => break,
            }
        }
        debug!(
            "RTR cache delta-size pruning evaluated: snapshot_wire_size={}, cumulative_delta_wire_size={}, dropped_serials={:?}",
            snapshot_wire_size, cumulative_delta_wire_size, dropped_serials
        );
    }

    /// Serials of the oldest and newest retained deltas, or `None` when no
    /// deltas are retained.
    fn delta_window(state: &VersionState) -> Option<(u32, u32)> {
        let oldest = state.deltas.front()?.serial();
        let newest = state.deltas.back()?.serial();
        Some((oldest, newest))
    }

    /// Builds a snapshot from `new_payloads` and applies it like
    /// `apply_update_from_snapshot`.
    pub(super) fn apply_update(
        &mut self,
        new_payloads: Vec<Payload>,
    ) -> Result<Option<AppliedUpdate>> {
        let source_snapshot = Snapshot::from_payloads(new_payloads);
        self.apply_update_from_snapshot(source_snapshot)
    }

    /// Applies a new full data set to every protocol version.
    ///
    /// Outcomes:
    /// - An empty `source_snapshot` puts the cache into `NoDataAvailable`.
    /// - The first non-empty snapshot after that reinitializes the cache
    ///   (new session IDs, serial 1).
    /// - Otherwise each version whose projected content changed gets a new
    ///   serial and a new delta.
    ///
    /// Returns `Ok(None)` when nothing changed.
    pub(super) fn apply_update_from_snapshot(
        &mut self,
        source_snapshot: Snapshot,
    ) -> Result<Option<AppliedUpdate>> {
        self.last_update_begin = DualTime::now();
        info!(
            "RTR cache applying update: availability={:?}, current_serials={:?}, incoming_snapshot_sizes=(origins={}, router_keys={}, aspas={})",
            self.availability,
            self.serials(),
            source_snapshot.origins().len(),
            source_snapshot.router_keys().len(),
            source_snapshot.aspas().len()
        );
        if source_snapshot.is_empty() {
            let changed = self.availability != CacheAvailability::NoDataAvailable
                || self.versions.iter().any(|state| !state.snapshot.is_empty())
                || self.versions.iter().any(|state| !state.deltas.is_empty());
            self.set_unavailable();
            self.last_update_end = DualTime::now();
            if !changed {
                return Ok(None);
            }
            return Ok(Some(self.applied_update_with_clear()));
        }
        if self.availability == CacheAvailability::NoDataAvailable {
            return Ok(Some(self.reinitialize_from_snapshot(&source_snapshot)));
        }
        let mut changed_any = false;
        for version in 0..VERSION_COUNT {
            let v = version as u8;
            let projected = project_snapshot_for_version(&source_snapshot, v);
            let state = &mut self.versions[version];
            if state.snapshot.same_content(&projected) {
                continue;
            }
            let (announced, withdrawn) = state.snapshot.diff(&projected);
            if announced.is_empty() && withdrawn.is_empty() {
                continue;
            }
            // Build the delta against the *next* serial, but only commit the
            // serial bump once the delta is known to carry changes. Bumping
            // first left a serial with no matching delta whenever the delta
            // came out empty, which later forced clients on the old serial
            // into a reset (the gap check in `collect_deltas_since`).
            let candidate_serial = state.serial.wrapping_add(1);
            let delta = Arc::new(Delta::new(candidate_serial, announced, withdrawn));
            if delta.is_empty() {
                continue;
            }
            let new_serial = Self::next_serial(state);
            debug_assert_eq!(new_serial, candidate_serial);
            state.snapshot = Arc::new(projected);
            state.rtr_payloads = state.snapshot.rtr_payloads_for_rtr_arc();
            Self::push_delta(
                state,
                self.max_delta,
                self.prune_delta_by_snapshot_size,
                delta,
            );
            changed_any = true;
        }
        self.last_update_end = DualTime::now();
        if !changed_any {
            return Ok(None);
        }
        info!(
            "RTR cache applied update: serials={:?}, session_ids={:?}, delta_lengths={:?}",
            self.serials(),
            self.session_ids(),
            self.delta_lengths()
        );
        Ok(Some(self.applied_update_with_windows()))
    }

    /// Describes the current state with no deltas and every delta window
    /// flagged for clearing.
    fn applied_update_with_clear(&self) -> AppliedUpdate {
        let snapshots = std::array::from_fn(|idx| self.versions[idx].snapshot.clone());
        let serials = std::array::from_fn(|idx| self.versions[idx].serial);
        let session_ids = std::array::from_fn(|idx| self.versions[idx].session_id);
        AppliedUpdate {
            availability: self.availability,
            snapshots,
            serials,
            session_ids,
            deltas: std::array::from_fn(|_| None),
            delta_windows: [None; VERSION_COUNT],
            clear_delta_windows: [true; VERSION_COUNT],
        }
    }

    /// Describes the current state with each version's newest delta and its
    /// retained delta window.
    fn applied_update_with_windows(&self) -> AppliedUpdate {
        let snapshots = std::array::from_fn(|idx| self.versions[idx].snapshot.clone());
        let serials = std::array::from_fn(|idx| self.versions[idx].serial);
        let session_ids = std::array::from_fn(|idx| self.versions[idx].session_id);
        let deltas = std::array::from_fn(|idx| self.versions[idx].deltas.back().cloned());
        let delta_windows = std::array::from_fn(|idx| Self::delta_window(&self.versions[idx]));
        AppliedUpdate {
            availability: self.availability,
            snapshots,
            serials,
            session_ids,
            deltas,
            delta_windows,
            clear_delta_windows: [false; VERSION_COUNT],
        }
    }

    /// `true` when the cache is `Ready`.
    pub fn is_data_available(&self) -> bool {
        self.availability == CacheAvailability::Ready
    }
    /// Current availability state.
    pub fn availability(&self) -> CacheAvailability {
        self.availability
    }
    /// Session ID for `version`; panics for unsupported versions.
    pub fn session_id_for_version(&self, version: u8) -> u16 {
        self.versions[version_index(version)].session_id
    }
    /// Session IDs of all versions.
    pub fn session_ids(&self) -> SessionIds {
        SessionIds::from_array(std::array::from_fn(|idx| self.versions[idx].session_id))
    }
    /// Owned copy of the snapshot for `version`; panics for unsupported versions.
    pub fn snapshot_for_version(&self, version: u8) -> Snapshot {
        self.versions[version_index(version)].snapshot.as_ref().clone()
    }
    /// Shared cached RTR payload list for `version`; panics for unsupported versions.
    pub fn rtr_payloads_for_version(&self, version: u8) -> Arc<Vec<Payload>> {
        self.versions[version_index(version)].rtr_payloads.clone()
    }
    /// Current serial for `version`; panics for unsupported versions.
    pub fn serial_for_version(&self, version: u8) -> u32 {
        self.versions[version_index(version)].serial
    }
    /// Current serials of all versions.
    pub fn serials(&self) -> [u32; VERSION_COUNT] {
        std::array::from_fn(|idx| self.versions[idx].serial)
    }
    /// Number of retained deltas per version.
    pub fn delta_lengths(&self) -> [usize; VERSION_COUNT] {
        std::array::from_fn(|idx| self.versions[idx].deltas.len())
    }
    /// Configured timing parameters.
    pub fn timing(&self) -> Timing {
        self.timing
    }
    /// Start time of the most recent update.
    pub fn last_update_begin(&self) -> DualTime {
        self.last_update_begin.clone()
    }
    /// End time of the most recent update.
    pub fn last_update_end(&self) -> DualTime {
        self.last_update_end.clone()
    }
    /// Creation time of the cache.
    pub fn created_at(&self) -> DualTime {
        self.created_at.clone()
    }

    /// Decides how a client of `version` at `client_serial` should be updated.
    ///
    /// Returns `UpToDate` when no net change is pending. It returns
    /// `ResetRequired` in two cases:
    /// - the client serial is ahead of or incomparable with ours;
    /// - the retained deltas cannot bridge the gap.
    ///
    /// Otherwise it returns one merged `Delta` at the current serial.
    pub fn get_deltas_since_for_version(&self, version: u8, client_serial: u32) -> SerialResult {
        let state = &self.versions[version_index(version)];
        if client_serial == state.serial {
            return SerialResult::UpToDate;
        }
        if matches!(
            serial_cmp(client_serial, state.serial),
            Some(Ordering::Greater) | None
        ) {
            return SerialResult::ResetRequired;
        }
        let deltas = match collect_deltas_since(state, client_serial) {
            Some(deltas) => deltas,
            None => return SerialResult::ResetRequired,
        };
        if deltas.is_empty() {
            return SerialResult::UpToDate;
        }
        let merged = merge_deltas_minimally(state.serial, &deltas);
        if merged.is_empty() {
            SerialResult::UpToDate
        } else {
            SerialResult::Delta(merged)
        }
    }
}
/// Collects the retained deltas newer than `client_serial`, oldest first.
///
/// Returns `None` (the client must reset) in three cases:
/// - no deltas are retained;
/// - `client_serial` is older than the serial just before the oldest retained
///   delta, or incomparable with it;
/// - the first collected delta is not exactly `client_serial + 1`, i.e. the
///   window has a gap.
///
/// Returns `Some(vec![])` when no retained delta is newer than the client.
fn collect_deltas_since(state: &VersionState, client_serial: u32) -> Option<Vec<Arc<Delta>>> {
    // `?` covers the empty-window case without an `unwrap()`.
    let oldest_serial = state.deltas.front()?.serial();
    let min_supported = oldest_serial.wrapping_sub(1);
    if matches!(
        serial_cmp(client_serial, min_supported),
        Some(Ordering::Less) | None
    ) {
        return None;
    }
    let result: Vec<Arc<Delta>> = state
        .deltas
        .iter()
        .filter(|delta| serial_gt(delta.serial(), client_serial))
        .cloned()
        .collect();
    if let Some(first) = result.first() {
        if first.serial() != client_serial.wrapping_add(1) {
            return None;
        }
    }
    Some(result)
}
/// Folds a run of deltas into a single delta with the net change, stamped
/// with `current_serial`.
///
/// NOTE(review): assumes `deltas` are in ascending serial order with no gaps;
/// the caller in this file gets them from `collect_deltas_since`.
///
/// Payloads are grouped by `change_key`, and within each delta withdrawals
/// are applied before announcements. Per key, the fold tracks:
/// - `before`: the first payload withdrawn while the key had no recorded state;
/// - `after`: the most recently announced payload, reset to `None` by a later
///   withdrawal.
///
/// Net result per key:
/// - no before, no after: nothing (announced and then withdrawn again);
/// - no before, some after: announce the new payload;
/// - some before, no after: withdraw the old payload;
/// - both present and different: withdraw the old and announce the new,
///   except for ASPA payloads, where only the new payload is announced.
///   (NOTE(review): presumably an ASPA announcement replaces the previous one
///   for the same key — confirm against the RTR ASPA semantics.)
///
/// Output order follows `ChangeKey` ordering because a `BTreeMap` is used.
fn merge_deltas_minimally(current_serial: u32, deltas: &[Arc<Delta>]) -> Delta {
    let mut states = BTreeMap::<ChangeKey, LogicalStateRef<'_>>::new();
    for delta in deltas {
        for payload in delta.withdrawn() {
            let key = change_key(payload);
            let state = states.entry(key).or_insert_with(LogicalStateRef::new);
            // Only the first withdrawal of a key not yet touched records the
            // client's original payload; a withdrawal after an announcement
            // in this window cancels that announcement instead.
            if state.before.is_none() && state.after.is_none() {
                state.before = Some(payload);
            }
            state.after = None;
        }
        for payload in delta.announced() {
            let key = change_key(payload);
            let state = states.entry(key).or_insert_with(LogicalStateRef::new);
            // The latest announcement wins.
            state.after = Some(payload);
        }
    }
    let mut announced = Vec::new();
    let mut withdrawn = Vec::new();
    for (_key, state) in states {
        match (state.before, state.after) {
            (None, None) => {}
            (None, Some(new_payload)) => announced.push(new_payload.clone()),
            (Some(old_payload), None) => withdrawn.push(old_payload.clone()),
            (Some(old_payload), Some(new_payload)) => {
                // Identical before/after means the changes cancelled out.
                if old_payload != new_payload {
                    if matches!(old_payload, Payload::Aspa(_))
                        && matches!(new_payload, Payload::Aspa(_))
                    {
                        announced.push(new_payload.clone());
                    } else {
                        withdrawn.push(old_payload.clone());
                        announced.push(new_payload.clone());
                    }
                }
            }
        }
    }
    Delta::new(current_serial, announced, withdrawn)
}
fn project_snapshot_for_version(snapshot: &Snapshot, version: u8) -> Snapshot {
snapshot.project_for_version(version)
}
/// Estimated bytes on the wire to send every payload of `snapshot` as an
/// announcement.
fn estimate_snapshot_payload_wire_size(snapshot: &Snapshot) -> usize {
    let payloads = snapshot.rtr_payloads_for_rtr_arc();
    let mut total = 0;
    for payload in payloads.iter() {
        total += estimate_payload_wire_size(payload, true);
    }
    total
}
/// Sum of the estimated wire sizes of every delta in the window.
fn estimate_delta_window_payload_wire_size(deltas: &VecDeque<Arc<Delta>>) -> usize {
    let mut total = 0;
    for delta in deltas {
        total += estimate_delta_wire_size(delta);
    }
    total
}
/// Estimated bytes on the wire for every announcement and withdrawal in `delta`.
fn estimate_delta_wire_size(delta: &Delta) -> usize {
    let updates = delta.payload_updates_for_rtr();
    updates.iter().fold(0, |acc, (announce, payload)| {
        acc + estimate_payload_wire_size(payload, *announce)
    })
}
/// Estimated size in bytes of the RTR PDU that carries `payload`.
///
/// `announce` only matters for ASPA payloads: withdrawals are estimated
/// without the provider list.
fn estimate_payload_wire_size(payload: &Payload, announce: bool) -> usize {
    use crate::data_model::resources::ip_resources::IPAddress;

    const IPV4_PREFIX_PDU_LEN: usize = 20;
    const IPV6_PREFIX_PDU_LEN: usize = 32;
    const PDU_HEADER_LEN: usize = 8;
    const SKI_LEN: usize = 20;
    const ASN_LEN: usize = 4;

    match payload {
        Payload::RouteOrigin(origin) => match origin.prefix().address {
            IPAddress::V4(_) => IPV4_PREFIX_PDU_LEN,
            IPAddress::V6(_) => IPV6_PREFIX_PDU_LEN,
        },
        Payload::RouterKey(key) => PDU_HEADER_LEN + SKI_LEN + ASN_LEN + key.spki().len(),
        Payload::Aspa(aspa) => {
            // Header, customer ASN, then one ASN per provider when announcing.
            let provider_count = if announce {
                aspa.provider_asns().len()
            } else {
                0
            };
            PDU_HEADER_LEN + ASN_LEN + provider_count * ASN_LEN
        }
    }
}
/// Per-key bookkeeping used by `merge_deltas_minimally` while folding deltas.
#[derive(Debug, Clone, Copy, Default)]
struct LogicalStateRef<'a> {
    /// First payload withdrawn for the key while it had no recorded state.
    before: Option<&'a Payload>,
    /// Latest announced payload for the key, or `None` after a withdrawal.
    after: Option<&'a Payload>,
}
impl LogicalStateRef<'_> {
    /// A state with neither a `before` nor an `after` payload.
    fn new() -> Self {
        Self::default()
    }
}
/// Outcome of asking the cache how to bring a client up to date.
///
/// Derives `Debug` like the other public cache types; `Delta` is already
/// `Debug` because `VersionState` derives it.
#[derive(Debug)]
pub enum SerialResult {
    /// The client already has the current data.
    UpToDate,
    /// The net change the client must apply, stamped with the current serial.
    Delta(Delta),
    /// The client cannot be updated incrementally and must reset.
    ResetRequired,
}
/// Summary of a cache update handed to the rest of the `cache` module.
/// All arrays are indexed by `version_index`.
#[derive(Clone)]
pub(super) struct AppliedUpdate {
    /// Availability after the update.
    pub(super) availability: CacheAvailability,
    /// Snapshot of each version after the update.
    pub(super) snapshots: [Arc<Snapshot>; VERSION_COUNT],
    /// Serial of each version after the update.
    pub(super) serials: [u32; VERSION_COUNT],
    /// Session ID of each version after the update.
    pub(super) session_ids: [u16; VERSION_COUNT],
    /// Newest retained delta per version, if any (always `None` for
    /// clear-style updates).
    pub(super) deltas: [Option<Arc<Delta>>; VERSION_COUNT],
    /// (oldest, newest) serials of the retained deltas per version.
    pub(super) delta_windows: [Option<(u32, u32)>; VERSION_COUNT],
    /// `true` when consumers should discard their delta windows for that version.
    pub(super) clear_delta_windows: [bool; VERSION_COUNT],
}
fn serial_cmp(a: u32, b: u32) -> Option<Ordering> {
if a == b {
return Some(Ordering::Equal);
}
let diff = a.wrapping_sub(b);
if diff == SERIAL_HALF_RANGE {
None
} else if diff < SERIAL_HALF_RANGE {
Some(Ordering::Greater)
} else {
Some(Ordering::Less)
}
}
/// `true` only when `a` is strictly newer than `b` under `serial_cmp`;
/// incomparable serials count as not greater.
fn serial_gt(a: u32, b: u32) -> bool {
    serial_cmp(a, b) == Some(Ordering::Greater)
}
/// Maps an RTR protocol version (0, 1 or 2) to its slot in per-version arrays.
///
/// # Panics
/// Panics for any other version number.
fn version_index(version: u8) -> usize {
    if version > 2 {
        panic!("unsupported RTR protocol version: {}", version);
    }
    usize::from(version)
}