rpki/src/main_rtr.rs
2026-06-16 14:05:26 +08:00

427 lines
16 KiB
Rust

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use arc_swap::ArcSwap;
use chrono::Utc;
use tokio::task::JoinHandle;
use tracing::{info, warn};
use rpki::rtr::cache::{RtrCache, SharedRtrCache, Snapshot};
use rpki::rtr::config::{AppConfig, log_startup_config};
use rpki::rtr::report::{ReportConfiguration, ReportContext, current_rss_mib};
use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceStats, RunningRtrService};
use rpki::rtr::store::RtrStore;
use rpki::source::pipeline::{
PayloadLoadConfig, SourceFingerprint, latest_sources_fingerprint,
load_payloads_from_latest_sources_with_report,
};
/// Service entry point.
///
/// Loads configuration from the environment, sets up logging and reporting, opens the store,
/// builds the shared cache, starts the RTR listeners and the background refresh task, and then
/// blocks until a shutdown signal arrives, after which everything is torn down in order.
#[tokio::main]
async fn main() -> Result<()> {
    let config = AppConfig::from_env()?;
    init_tracing(config.timezone);
    log_startup_config(&config);

    // Collect the scalar settings the report layer needs before building it.
    let refresh_secs = config.source_refresh_interval.as_secs();
    let report_secs = config.report_interval.as_secs();
    let rtr_timers = (
        config.timing.refresh,
        config.timing.retry,
        config.timing.expire,
    );
    let report_configuration = ReportConfiguration::new(
        refresh_secs,
        report_secs,
        config.max_delta,
        config.prune_delta_by_snapshot_size,
        config.strict_ccr_validation,
        config.timezone,
        rtr_timers,
    );
    let report_context = ReportContext::new(report_configuration);

    let store = open_store(&config)?;
    let shared_cache = init_shared_cache(&config, &store, &report_context)?;

    let service = RtrService::with_config(shared_cache.clone(), config.service_config.clone());
    let cache_notifier = service.notifier();
    let stats = service.stats();
    let running = start_servers(&config, &service);

    let refresh_task = spawn_refresh_task(
        &config,
        shared_cache.clone(),
        store.clone(),
        cache_notifier,
        stats,
        report_context,
    );

    wait_for_shutdown().await?;

    // Stop accepting/serving RTR sessions first, then stop the background refresh.
    running.shutdown();
    running.wait().await;
    refresh_task.abort();
    // The join result is only awaited to let the aborted task finish; its outcome is irrelevant.
    let _ = refresh_task.await;

    info!("RTR service stopped");
    Ok(())
}
/// Opens the RTR store located at `config.db_path`, logging the path first.
fn open_store(config: &AppConfig) -> Result<RtrStore> {
    let db_path = &config.db_path;
    info!("opening RTR store: {}", db_path);
    RtrStore::open(db_path)
}
/// Builds the initial RTR cache and wraps it as a [`SharedRtrCache`].
///
/// `RtrCache::init` is given the store, delta limits, RTR timing and a loader closure. The
/// loader reads the latest CCR/SLURM sources, logs how long that took, and records a
/// successful refresh in `report_context`.
///
/// # Errors
/// Propagates errors from the source loader (via `init`) and from `RtrCache::init` itself.
fn init_shared_cache(
    config: &AppConfig,
    store: &RtrStore,
    report_context: &ReportContext,
) -> Result<SharedRtrCache> {
    let load_config = PayloadLoadConfig {
        ccr_dir: config.ccr_dir.clone(),
        slurm_dir: config.slurm_dir.clone(),
        strict_ccr_validation: config.strict_ccr_validation,
    };
    let started = Instant::now();
    // The loader records into its own handle to the report context.
    let loader_report = report_context.clone();

    let cache = RtrCache::default().init(
        store,
        config.max_delta,
        config.prune_delta_by_snapshot_size,
        config.timing,
        || {
            let loaded = load_payloads_from_latest_sources_with_report(&load_config)?;
            info!(
                "RTR source-to-delta timing: phase=startup_load_complete, ccr_dir={}, payload_count={}, elapsed_ms={}",
                load_config.ccr_dir,
                loaded.payloads.len(),
                started.elapsed().as_millis()
            );
            loader_report.record_refresh_success(
                Utc::now(),
                started.elapsed().as_millis(),
                true,
                loaded.source,
                loaded.quality,
            );
            Ok(loaded.payloads)
        },
    )?;

    info!(
        "RTR source-to-delta timing: phase=startup_cache_init_complete, ccr_dir={}, serials={:?}, elapsed_ms={}",
        load_config.ccr_dir,
        cache.serials(),
        started.elapsed().as_millis()
    );

    let shared: SharedRtrCache = Arc::new(ArcSwap::from_pointee(cache));
    let published = shared.load_full();
    info!(
        "cache initialized: session_ids={:?}, serials={:?}",
        published.session_ids(),
        published.serials()
    );
    Ok(shared)
}
/// Starts the RTR listeners selected by the configuration and returns their running handle.
///
/// A plain TCP listener is always started. A TLS listener is added when `enable_tls` is set,
/// and an SSH listener when `enable_ssh` is set.
fn start_servers(config: &AppConfig, service: &RtrService) -> RunningRtrService {
    match (config.enable_tls, config.enable_ssh) {
        (true, true) => {
            info!("starting TCP, TLS and SSH RTR servers");
            service.spawn_tcp_tls_and_ssh_from_pem_and_openssh(
                config.tcp_addr,
                config.tls_addr,
                config.ssh_addr,
                &config.tls_cert_path,
                &config.tls_key_path,
                &config.tls_client_ca_path,
                &config.ssh_host_key_path,
                &config.ssh_authorized_keys_path,
                &config.ssh_username,
                &config.ssh_subsystem_name,
                config.ssh_auth_mode,
                config.ssh_password.as_deref(),
            )
        }
        (true, false) => {
            info!("starting TCP and TLS RTR servers");
            service.spawn_tcp_and_tls_from_pem(
                config.tcp_addr,
                config.tls_addr,
                &config.tls_cert_path,
                &config.tls_key_path,
                &config.tls_client_ca_path,
            )
        }
        (false, true) => {
            info!("starting TCP and SSH RTR servers");
            service.spawn_tcp_and_ssh_from_openssh(
                config.tcp_addr,
                config.ssh_addr,
                &config.ssh_host_key_path,
                &config.ssh_authorized_keys_path,
                &config.ssh_username,
                &config.ssh_subsystem_name,
                config.ssh_auth_mode,
                config.ssh_password.as_deref(),
            )
        }
        (false, false) => {
            info!("starting TCP RTR server");
            service.spawn_tcp_only(config.tcp_addr)
        }
    }
}
fn spawn_refresh_task(
config: &AppConfig,
shared_cache: SharedRtrCache,
store: RtrStore,
notifier: RtrNotifier,
service_stats: RtrServiceStats,
report_context: ReportContext,
) -> JoinHandle<()> {
let refresh_interval = config.source_refresh_interval;
let report_interval = config.report_interval;
let report_dir = PathBuf::from(&config.report_dir);
let payload_load_config = PayloadLoadConfig {
ccr_dir: config.ccr_dir.clone(),
slurm_dir: config.slurm_dir.clone(),
strict_ccr_validation: config.strict_ccr_validation,
};
tokio::spawn(async move {
let mut interval = tokio::time::interval(refresh_interval);
let mut last_fingerprint: Option<SourceFingerprint> = None;
report_context.write_or_warn(
&report_dir,
"startup",
&shared_cache,
&notifier,
&service_stats,
);
let mut stats_interval = tokio::time::interval_at(
tokio::time::Instant::now() + report_interval,
report_interval,
);
stats_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = stats_interval.tick() => {
log_cache_memory_stats("periodic_observe", &shared_cache, &notifier);
report_context.write_or_warn(&report_dir, "periodic", &shared_cache, &notifier, &service_stats);
continue;
}
_ = interval.tick() => {}
}
let source_to_delta_started = Instant::now();
let attempted_at = Utc::now();
let current_fingerprint = match latest_sources_fingerprint(&payload_load_config) {
Ok(fp) => fp,
Err(err) => {
report_context.record_refresh_failure(
attempted_at,
source_to_delta_started.elapsed().as_millis(),
&err,
);
warn!(
"failed to fingerprint CCR/SLURM sources from {}: {:?} (source_to_delta_elapsed_ms={})",
payload_load_config.ccr_dir,
err,
source_to_delta_started.elapsed().as_millis()
);
report_context.write_or_warn(
&report_dir,
"refresh_failed",
&shared_cache,
&notifier,
&service_stats,
);
continue;
}
};
if last_fingerprint.as_ref() == Some(&current_fingerprint) {
report_context.record_refresh_unchanged(
attempted_at,
source_to_delta_started.elapsed().as_millis(),
);
info!(
"RTR source refresh skipped: source files unchanged (ccr_path={}, slurm_file_count={}, elapsed_ms={})",
current_fingerprint.ccr.path,
current_fingerprint.slurm_files.len(),
source_to_delta_started.elapsed().as_millis()
);
log_cache_memory_stats("refresh_skipped_unchanged", &shared_cache, &notifier);
report_context.write_or_warn(
&report_dir,
"refresh_skipped_unchanged",
&shared_cache,
&notifier,
&service_stats,
);
continue;
}
match load_payloads_from_latest_sources_with_report(&payload_load_config) {
Ok(load) => {
let source = load.source;
let quality = load.quality;
let payloads = load.payloads;
let (payload_count, updated) = {
let payload_count = payloads.len();
let source_snapshot = Snapshot::from_payloads(payloads);
let old_cache = shared_cache.load_full();
let old_serial = old_cache.serial_for_version(2);
let mut next_cache = old_cache.as_ref().clone();
let updated = match next_cache.update_with_snapshot(source_snapshot, &store)
{
Ok(()) => {
let new_serial = next_cache.serial_for_version(2);
shared_cache.store(std::sync::Arc::new(next_cache));
if new_serial != old_serial {
info!(
"RTR cache refresh applied: ccr_dir={}, payload_count={}, old_serial={}, new_serial={}",
payload_load_config.ccr_dir,
payload_count,
old_serial,
new_serial
);
true
} else {
info!(
"RTR cache refresh found no change: ccr_dir={}, payload_count={}, serial={}",
payload_load_config.ccr_dir, payload_count, old_serial
);
false
}
}
Err(err) => {
report_context.record_refresh_failure(
attempted_at,
source_to_delta_started.elapsed().as_millis(),
&err,
);
warn!("RTR cache update failed: {:?}", err);
report_context.write_or_warn(
&report_dir,
"refresh_failed",
&shared_cache,
&notifier,
&service_stats,
);
continue;
}
};
(payload_count, updated)
};
report_context.record_refresh_success(
attempted_at,
source_to_delta_started.elapsed().as_millis(),
updated,
source,
quality,
);
info!(
"RTR source-to-delta timing: phase=refresh_cache_update_complete, ccr_dir={}, payload_count={}, changed={}, elapsed_ms={}",
payload_load_config.ccr_dir,
payload_count,
updated,
source_to_delta_started.elapsed().as_millis()
);
if updated {
let listener_count = notifier.notify_cache_updated();
info!(
"RTR cache updated, notify signal emitted to session listeners: listener_count={}",
listener_count
);
}
log_cache_memory_stats("refresh_complete", &shared_cache, &notifier);
report_context.write_or_warn(
&report_dir,
"refresh_complete",
&shared_cache,
&notifier,
&service_stats,
);
last_fingerprint = Some(current_fingerprint);
}
Err(err) => {
report_context.record_refresh_failure(
attempted_at,
source_to_delta_started.elapsed().as_millis(),
&err,
);
warn!(
"failed to reload CCR/SLURM payloads from {}: {:?} (source_to_delta_elapsed_ms={})",
payload_load_config.ccr_dir,
err,
source_to_delta_started.elapsed().as_millis()
);
report_context.write_or_warn(
&report_dir,
"refresh_failed",
&shared_cache,
&notifier,
&service_stats,
);
}
}
}
})
}
/// Emits one INFO line summarising the current cache's memory statistics.
///
/// `phase` labels the point in the service lifecycle that triggered the observation. The line
/// includes the session listener count from `notifier`, the cache's `memory_stats()` and the
/// process RSS reported by `current_rss_mib()`.
fn log_cache_memory_stats(phase: &str, shared_cache: &SharedRtrCache, notifier: &RtrNotifier) {
    let current = shared_cache.load_full();
    let mem = current.memory_stats();
    let rss = current_rss_mib();
    info!(
        "RTR memory observe: phase={}, listener_count={}, serials={:?}, snapshot_payload_counts={:?}, delta_lengths={:?}, delta_payload_counts={:?}, snapshot_arc_strong_counts={:?}, rtr_payloads_arc_strong_counts={:?}, rss_mib={:?}",
        phase,
        notifier.listener_count(),
        mem.serials,
        mem.snapshot_payload_counts,
        mem.delta_lengths,
        mem.delta_payload_counts,
        mem.snapshot_arc_strong_counts,
        mem.rtr_payloads_arc_strong_counts,
        rss
    );
}
/// Waits until the process is asked to shut down.
///
/// Resolves on Ctrl-C (SIGINT) on every platform and, on Unix, also on SIGTERM, which is
/// the signal service managers and container runtimes send to stop a daemon. Without a
/// SIGTERM handler the default action would kill the process immediately, skipping the
/// graceful shutdown performed by the caller.
///
/// # Errors
/// Returns an error if a signal handler cannot be installed.
async fn wait_for_shutdown() -> Result<()> {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};

        // Install the SIGTERM handler before waiting so no signal is lost in between.
        let mut terminate = signal(SignalKind::terminate())?;
        tokio::select! {
            res = tokio::signal::ctrl_c() => res?,
            _ = terminate.recv() => {}
        }
    }
    #[cfg(not(unix))]
    {
        tokio::signal::ctrl_c().await?;
    }
    info!("shutdown signal received");
    Ok(())
}
/// Installs the global `tracing` subscriber.
///
/// Event timestamps are rendered in `timezone` with millisecond precision and a UTC offset.
/// The level filter comes from `EnvFilter::try_from_default_env()`; if that fails, only
/// `warn` and above are emitted. A failure to install the subscriber is reported on stderr
/// rather than aborting startup.
fn init_tracing(timezone: chrono_tz::Tz) {
    /// Timer that formats the current instant in a fixed display timezone.
    struct LocalTimer {
        timezone: chrono_tz::Tz,
    }

    impl tracing_subscriber::fmt::time::FormatTime for LocalTimer {
        fn format_time(
            &self,
            w: &mut tracing_subscriber::fmt::format::Writer<'_>,
        ) -> std::fmt::Result {
            let local_now = Utc::now().with_timezone(&self.timezone);
            write!(w, "{}", local_now.format("%Y-%m-%d %H:%M:%S%.3f %:z"))
        }
    }

    let env_filter = match tracing_subscriber::EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => tracing_subscriber::EnvFilter::new("warn"),
    };

    let subscriber = tracing_subscriber::fmt()
        .with_timer(LocalTimer { timezone })
        .with_env_filter(env_filter)
        .with_target(true)
        .with_thread_ids(true)
        .with_level(true);

    if let Err(err) = subscriber.try_init() {
        eprintln!("failed to initialize tracing subscriber: {err}");
    }
}