// rpki/src/main_rtr.rs — RTR server binary (listing header: 637 lines, 24 KiB, Rust).

use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Result, anyhow};
use arc_swap::ArcSwap;
use chrono::{FixedOffset, Utc};
use tokio::task::JoinHandle;
use tracing::{info, warn};
use rpki::rtr::cache::{RtrCache, SharedRtrCache, Snapshot};
use rpki::rtr::payload::Timing;
use rpki::rtr::server::ssh::SshAuthMode;
use rpki::rtr::server::{RtrNotifier, RtrService, RtrServiceConfig, RunningRtrService};
use rpki::rtr::store::RtrStore;
use rpki::source::pipeline::{PayloadLoadConfig, load_payloads_from_latest_sources};
/// Runtime configuration for the RTR server binary.
///
/// Built by `AppConfig::from_env`, which starts from `AppConfig::default` and overrides
/// individual fields from `RPKI_RTR_*` environment variables.
#[derive(Debug, Clone)]
struct AppConfig {
    /// Also start the TLS listener on `tls_addr` (`RPKI_RTR_ENABLE_TLS`).
    enable_tls: bool,
    /// Also start the SSH listener on `ssh_addr` (`RPKI_RTR_ENABLE_SSH`).
    enable_ssh: bool,
    /// Plain-TCP listen address; the TCP listener is started in every mode
    /// (`RPKI_RTR_TCP_ADDR`).
    tcp_addr: SocketAddr,
    /// TLS listen address (`RPKI_RTR_TLS_ADDR`).
    tls_addr: SocketAddr,
    /// SSH listen address (`RPKI_RTR_SSH_ADDR`); its port can be replaced separately via
    /// `RPKI_RTR_SSH_PORT`.
    ssh_addr: SocketAddr,
    /// Path handed to `RtrStore::open` (`RPKI_RTR_DB_PATH`).
    db_path: String,
    /// Directory the CCR payloads are loaded from (`RPKI_RTR_CCR_DIR`).
    ccr_dir: String,
    /// Optional SLURM directory; `None` when unset or blank (`RPKI_RTR_SLURM_DIR`).
    slurm_dir: Option<String>,
    /// PEM server certificate for the TLS listener (`RPKI_RTR_TLS_CERT_PATH`).
    tls_cert_path: String,
    /// PEM private key for the TLS listener (`RPKI_RTR_TLS_KEY_PATH`).
    tls_key_path: String,
    /// Client CA file passed to the TLS listener (`RPKI_RTR_TLS_CLIENT_CA_PATH`).
    tls_client_ca_path: String,
    /// OpenSSH host key file for the SSH listener (`RPKI_RTR_SSH_HOST_KEY_PATH`).
    ssh_host_key_path: String,
    /// Authorized-keys file for the SSH listener (`RPKI_RTR_SSH_AUTHORIZED_KEYS_PATH`).
    ssh_authorized_keys_path: String,
    /// SSH user name passed to the SSH listener (`RPKI_RTR_SSH_USERNAME`).
    ssh_username: String,
    /// SSH subsystem name passed to the SSH listener (`RPKI_RTR_SSH_SUBSYSTEM_NAME`).
    ssh_subsystem_name: String,
    /// SSH authentication mode, `key|password|both` (`RPKI_RTR_SSH_AUTH_MODE`).
    ssh_auth_mode: SshAuthMode,
    /// Optional SSH password; `None` when unset or blank (`RPKI_RTR_SSH_PASSWORD`).
    /// Only its presence is logged at startup, never the value.
    ssh_password: Option<String>,
    /// Delta limit passed to `RtrCache::init`; must be >= 1 (`RPKI_RTR_MAX_DELTA`).
    max_delta: u8,
    /// Flag passed to `RtrCache::init` (`RPKI_RTR_PRUNE_DELTA_BY_SNAPSHOT_SIZE`).
    prune_delta_by_snapshot_size: bool,
    /// Forwarded to `PayloadLoadConfig::strict_ccr_validation`
    /// (`RPKI_RTR_STRICT_CCR_VALIDATION`).
    strict_ccr_validation: bool,
    /// Period of the background source reload (`RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS`,
    /// or the deprecated `RPKI_RTR_REFRESH_INTERVAL_SECS`).
    source_refresh_interval: Duration,
    /// RTR refresh/retry/expire values (`RPKI_RTR_TIMING_*_SECS`), checked with
    /// `Timing::validate` during `from_env`.
    timing: Timing,
    /// Connection, handshake, keepalive and TLS policy settings handed to
    /// `RtrService::with_config`.
    service_config: RtrServiceConfig,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
enable_tls: false,
enable_ssh: false,
tcp_addr: "0.0.0.0:323".parse().expect("invalid default tcp_addr"),
tls_addr: "0.0.0.0:324".parse().expect("invalid default tls_addr"),
ssh_addr: "0.0.0.0:22".parse().expect("invalid default ssh_addr"),
db_path: "./rtr-db".to_string(),
ccr_dir: "./data".to_string(),
slurm_dir: None,
tls_cert_path: "./certs/server.crt".to_string(),
tls_key_path: "./certs/server.key".to_string(),
tls_client_ca_path: "./certs/client-ca.crt".to_string(),
ssh_host_key_path: "./certs/ssh_host_ed25519_key".to_string(),
ssh_authorized_keys_path: "./certs/rtr-authorized_keys".to_string(),
ssh_username: "rpki-rtr".to_string(),
ssh_subsystem_name: "rpki-rtr".to_string(),
ssh_auth_mode: SshAuthMode::Key,
ssh_password: None,
max_delta: 100,
prune_delta_by_snapshot_size: false,
strict_ccr_validation: false,
source_refresh_interval: Duration::from_secs(300),
timing: Timing::default(),
service_config: RtrServiceConfig {
max_connections: 512,
max_concurrent_handshakes: 128,
notify_queue_size: 1024,
tcp_keepalive: Some(Duration::from_secs(60)),
warn_insecure_tcp: true,
require_tls_server_dns_name_san: false,
enforce_tls_client_san_ip_match: true,
},
}
}
}
impl AppConfig {
    /// Builds the configuration from `RPKI_RTR_*` environment variables.
    ///
    /// Starts from [`AppConfig::default`] and overrides only the variables that are set.
    /// Variables are processed in a fixed order: transport, paths, SSH, cache, source
    /// refresh, RTR timing, then service limits. The first invalid variable therefore
    /// determines the returned error.
    ///
    /// # Errors
    ///
    /// Fails when:
    /// - a variable cannot be read (see `env_var`) or parsed;
    /// - a value that must be positive is zero;
    /// - `Timing::validate` rejects the RTR timing values;
    /// - the connection/handshake limits are zero or inconsistent.
    fn from_env() -> Result<Self> {
        let mut config = Self::default();
        config.apply_transport_env()?;
        config.apply_path_env()?;
        config.apply_ssh_env()?;
        config.apply_cache_env()?;
        config.apply_source_refresh_env()?;
        config.apply_timing_env()?;
        config.apply_service_env()?;
        config.validate_service_limits()?;
        Ok(config)
    }

    /// Applies the transport toggles and listen addresses.
    ///
    /// `RPKI_RTR_SSH_PORT` is applied after `RPKI_RTR_SSH_ADDR`. When both are set, the
    /// explicit port replaces the port of the configured SSH address.
    fn apply_transport_env(&mut self) -> Result<()> {
        Self::override_bool(&mut self.enable_tls, "RPKI_RTR_ENABLE_TLS")?;
        Self::override_bool(&mut self.enable_ssh, "RPKI_RTR_ENABLE_SSH")?;
        Self::override_parsed(&mut self.tcp_addr, "RPKI_RTR_TCP_ADDR")?;
        Self::override_parsed(&mut self.tls_addr, "RPKI_RTR_TLS_ADDR")?;
        Self::override_parsed(&mut self.ssh_addr, "RPKI_RTR_SSH_ADDR")?;
        if let Some(port) = Self::parsed_env::<u16>("RPKI_RTR_SSH_PORT")? {
            self.ssh_addr.set_port(port);
        }
        Ok(())
    }

    /// Applies the store, source and TLS file locations.
    ///
    /// A blank `RPKI_RTR_SLURM_DIR` clears `slurm_dir`. The other paths are used verbatim.
    fn apply_path_env(&mut self) -> Result<()> {
        Self::override_string(&mut self.db_path, "RPKI_RTR_DB_PATH")?;
        Self::override_string(&mut self.ccr_dir, "RPKI_RTR_CCR_DIR")?;
        if let Some(value) = env_var("RPKI_RTR_SLURM_DIR")? {
            self.slurm_dir = Self::non_blank(&value);
        }
        Self::override_string(&mut self.tls_cert_path, "RPKI_RTR_TLS_CERT_PATH")?;
        Self::override_string(&mut self.tls_key_path, "RPKI_RTR_TLS_KEY_PATH")?;
        Self::override_string(&mut self.tls_client_ca_path, "RPKI_RTR_TLS_CLIENT_CA_PATH")?;
        Ok(())
    }

    /// Applies the SSH listener settings.
    ///
    /// A blank `RPKI_RTR_SSH_PASSWORD` clears the password. A non-blank one is trimmed.
    fn apply_ssh_env(&mut self) -> Result<()> {
        Self::override_string(&mut self.ssh_host_key_path, "RPKI_RTR_SSH_HOST_KEY_PATH")?;
        Self::override_string(
            &mut self.ssh_authorized_keys_path,
            "RPKI_RTR_SSH_AUTHORIZED_KEYS_PATH",
        )?;
        Self::override_string(&mut self.ssh_username, "RPKI_RTR_SSH_USERNAME")?;
        Self::override_string(&mut self.ssh_subsystem_name, "RPKI_RTR_SSH_SUBSYSTEM_NAME")?;
        if let Some(value) = env_var("RPKI_RTR_SSH_AUTH_MODE")? {
            self.ssh_auth_mode = SshAuthMode::parse(&value).ok_or_else(|| {
                anyhow!(
                    "invalid RPKI_RTR_SSH_AUTH_MODE '{}': expected key|password|both",
                    value
                )
            })?;
        }
        if let Some(value) = env_var("RPKI_RTR_SSH_PASSWORD")? {
            self.ssh_password = Self::non_blank(&value);
        }
        Ok(())
    }

    /// Applies the delta-history and CCR validation settings. `RPKI_RTR_MAX_DELTA` must
    /// be in `1..=255`.
    fn apply_cache_env(&mut self) -> Result<()> {
        if let Some(value) = env_var("RPKI_RTR_MAX_DELTA")? {
            let parsed: u8 = value
                .parse()
                .map_err(|err| anyhow!("invalid RPKI_RTR_MAX_DELTA '{}': {}", value, err))?;
            if parsed == 0 {
                return Err(anyhow!(
                    "invalid RPKI_RTR_MAX_DELTA '{}': must be >= 1",
                    value
                ));
            }
            self.max_delta = parsed;
        }
        Self::override_bool(
            &mut self.prune_delta_by_snapshot_size,
            "RPKI_RTR_PRUNE_DELTA_BY_SNAPSHOT_SIZE",
        )?;
        Self::override_bool(
            &mut self.strict_ccr_validation,
            "RPKI_RTR_STRICT_CCR_VALIDATION",
        )?;
        Ok(())
    }

    /// Applies the background source reload period.
    ///
    /// `RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS` wins over the deprecated
    /// `RPKI_RTR_REFRESH_INTERVAL_SECS`. A warning is logged whenever the legacy name is
    /// set. When both are set, the legacy value is not parsed.
    fn apply_source_refresh_env(&mut self) -> Result<()> {
        let source_refresh_interval_new = env_var("RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS")?;
        let source_refresh_interval_legacy = env_var("RPKI_RTR_REFRESH_INTERVAL_SECS")?;
        match (
            source_refresh_interval_new.as_deref(),
            source_refresh_interval_legacy.as_deref(),
        ) {
            (Some(new_value), Some(_)) => {
                let secs =
                    parse_positive_u64(new_value, "RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS")?;
                self.source_refresh_interval = Duration::from_secs(secs);
                warn!(
                    "both RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS and legacy RPKI_RTR_REFRESH_INTERVAL_SECS are set; using RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS"
                );
            }
            (Some(new_value), None) => {
                let secs =
                    parse_positive_u64(new_value, "RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS")?;
                self.source_refresh_interval = Duration::from_secs(secs);
            }
            (None, Some(legacy_value)) => {
                let secs = parse_positive_u64(legacy_value, "RPKI_RTR_REFRESH_INTERVAL_SECS")?;
                self.source_refresh_interval = Duration::from_secs(secs);
                warn!(
                    "RPKI_RTR_REFRESH_INTERVAL_SECS is deprecated; use RPKI_RTR_SOURCE_REFRESH_INTERVAL_SECS"
                );
            }
            (None, None) => {}
        }
        Ok(())
    }

    /// Applies the RTR refresh/retry/expire values (each must be >= 1). The resulting
    /// combination is then checked with `Timing::validate`.
    fn apply_timing_env(&mut self) -> Result<()> {
        if let Some(value) = env_var("RPKI_RTR_TIMING_REFRESH_SECS")? {
            self.timing.refresh = parse_positive_u32(&value, "RPKI_RTR_TIMING_REFRESH_SECS")?;
        }
        if let Some(value) = env_var("RPKI_RTR_TIMING_RETRY_SECS")? {
            self.timing.retry = parse_positive_u32(&value, "RPKI_RTR_TIMING_RETRY_SECS")?;
        }
        if let Some(value) = env_var("RPKI_RTR_TIMING_EXPIRE_SECS")? {
            self.timing.expire = parse_positive_u32(&value, "RPKI_RTR_TIMING_EXPIRE_SECS")?;
        }
        self.timing
            .validate()
            .map_err(|err| anyhow!("invalid RTR timing configuration: {}", err))?;
        Ok(())
    }

    /// Applies the `RtrServiceConfig` overrides. A keepalive of `0` disables TCP keepalive.
    fn apply_service_env(&mut self) -> Result<()> {
        let service = &mut self.service_config;
        Self::override_parsed(&mut service.max_connections, "RPKI_RTR_MAX_CONNECTIONS")?;
        Self::override_parsed(
            &mut service.max_concurrent_handshakes,
            "RPKI_RTR_MAX_CONCURRENT_HANDSHAKES",
        )?;
        Self::override_parsed(&mut service.notify_queue_size, "RPKI_RTR_NOTIFY_QUEUE_SIZE")?;
        if let Some(secs) = Self::parsed_env::<u64>("RPKI_RTR_TCP_KEEPALIVE_SECS")? {
            service.tcp_keepalive = if secs == 0 {
                None
            } else {
                Some(Duration::from_secs(secs))
            };
        }
        Self::override_bool(&mut service.warn_insecure_tcp, "RPKI_RTR_WARN_INSECURE_TCP")?;
        Self::override_bool(
            &mut service.require_tls_server_dns_name_san,
            "RPKI_RTR_REQUIRE_TLS_SERVER_DNS_NAME_SAN",
        )?;
        Self::override_bool(
            &mut service.enforce_tls_client_san_ip_match,
            "RPKI_RTR_ENFORCE_TLS_CLIENT_SAN_IP_MATCH",
        )?;
        Ok(())
    }

    /// Checks the connection limits:
    /// - both `max_connections` and `max_concurrent_handshakes` must be positive;
    /// - `max_concurrent_handshakes` must not exceed `max_connections`.
    fn validate_service_limits(&self) -> Result<()> {
        let limits = &self.service_config;
        if limits.max_connections == 0 {
            return Err(anyhow!(
                "invalid RPKI_RTR_MAX_CONNECTIONS '{}': must be >= 1",
                limits.max_connections
            ));
        }
        if limits.max_concurrent_handshakes == 0 {
            return Err(anyhow!(
                "invalid RPKI_RTR_MAX_CONCURRENT_HANDSHAKES '{}': must be >= 1",
                limits.max_concurrent_handshakes
            ));
        }
        if limits.max_concurrent_handshakes > limits.max_connections {
            return Err(anyhow!(
                "invalid handshake/connection limits: RPKI_RTR_MAX_CONCURRENT_HANDSHAKES ({}) must be <= RPKI_RTR_MAX_CONNECTIONS ({})",
                limits.max_concurrent_handshakes,
                limits.max_connections
            ));
        }
        Ok(())
    }

    /// Reads `name` and parses it with `FromStr`. Returns `Ok(None)` when the variable is
    /// unset. Parse failures become `invalid <name> '<value>': <err>`.
    fn parsed_env<T>(name: &str) -> Result<Option<T>>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display,
    {
        match env_var(name)? {
            Some(value) => value
                .parse::<T>()
                .map(Some)
                .map_err(|err| anyhow!("invalid {} '{}': {}", name, value, err)),
            None => Ok(None),
        }
    }

    /// Replaces `*target` with the parsed value of `name` when that variable is set.
    fn override_parsed<T>(target: &mut T, name: &str) -> Result<()>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display,
    {
        if let Some(parsed) = Self::parsed_env::<T>(name)? {
            *target = parsed;
        }
        Ok(())
    }

    /// Replaces `*target` with `parse_bool` of `name` when that variable is set.
    fn override_bool(target: &mut bool, name: &str) -> Result<()> {
        if let Some(value) = env_var(name)? {
            *target = parse_bool(&value, name)?;
        }
        Ok(())
    }

    /// Replaces `*target` with the raw (untrimmed) value of `name` when it is set.
    fn override_string(target: &mut String, name: &str) -> Result<()> {
        if let Some(value) = env_var(name)? {
            *target = value;
        }
        Ok(())
    }

    /// Returns the trimmed value, or `None` when it is empty after trimming.
    fn non_blank(value: &str) -> Option<String> {
        let trimmed = value.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    }
}
#[tokio::main]
async fn main() -> Result<()> {
init_tracing();
let config = AppConfig::from_env()?;
log_startup_config(&config);
let store = open_store(&config)?;
let shared_cache = init_shared_cache(&config, &store)?;
let service = RtrService::with_config(shared_cache.clone(), config.service_config.clone());
let notifier = service.notifier();
let running = start_servers(&config, &service);
let refresh_task = spawn_refresh_task(&config, shared_cache.clone(), store.clone(), notifier);
wait_for_shutdown().await?;
running.shutdown();
running.wait().await;
refresh_task.abort();
let _ = refresh_task.await;
info!("RTR service stopped");
Ok(())
}
/// Opens the RTR store located at `config.db_path` and logs the path being opened.
///
/// # Errors
/// Propagates any error returned by `RtrStore::open`.
fn open_store(config: &AppConfig) -> Result<RtrStore> {
    let db_path = &config.db_path;
    info!("opening RTR store: {}", db_path);
    RtrStore::open(db_path)
}
/// Builds the initial RTR cache and wraps it in the shared, atomically swappable handle.
///
/// `RtrCache::init` receives the store, delta limits, timing, and a loader closure. The
/// closure reads the latest CCR/SLURM payloads and logs how long that took.
/// NOTE(review): whether `init` always calls the loader, or only when the store has no
/// usable state, is decided inside `RtrCache::init`; confirm there.
///
/// # Errors
/// Propagates failures from `RtrCache::init`, including payload-loading errors raised
/// by the loader closure.
fn init_shared_cache(config: &AppConfig, store: &RtrStore) -> Result<SharedRtrCache> {
    let load_config = PayloadLoadConfig {
        ccr_dir: config.ccr_dir.clone(),
        slurm_dir: config.slurm_dir.clone(),
        strict_ccr_validation: config.strict_ccr_validation,
    };
    let started = Instant::now();

    let cache = RtrCache::default().init(
        store,
        config.max_delta,
        config.prune_delta_by_snapshot_size,
        config.timing,
        || {
            let loaded = load_payloads_from_latest_sources(&load_config)?;
            info!(
                "RTR source-to-delta timing: phase=startup_load_complete, ccr_dir={}, payload_count={}, elapsed_ms={}",
                load_config.ccr_dir,
                loaded.len(),
                started.elapsed().as_millis()
            );
            Ok(loaded)
        },
    )?;
    info!(
        "RTR source-to-delta timing: phase=startup_cache_init_complete, ccr_dir={}, serials={:?}, elapsed_ms={}",
        load_config.ccr_dir,
        cache.serials(),
        started.elapsed().as_millis()
    );

    let shared: SharedRtrCache = Arc::new(ArcSwap::from_pointee(cache));
    let published = shared.load_full();
    info!(
        "cache initialized: session_ids={:?}, serials={:?}",
        published.session_ids(),
        published.serials()
    );
    Ok(shared)
}
/// Starts the RTR listeners selected by `enable_tls` and `enable_ssh`.
///
/// The plain-TCP listener on `tcp_addr` is started in every combination. TLS is
/// configured from the PEM paths in `config`, and SSH from the OpenSSH key paths and
/// credentials.
fn start_servers(config: &AppConfig, service: &RtrService) -> RunningRtrService {
    match (config.enable_tls, config.enable_ssh) {
        (true, true) => {
            info!("starting TCP, TLS and SSH RTR servers");
            service.spawn_tcp_tls_and_ssh_from_pem_and_openssh(
                config.tcp_addr,
                config.tls_addr,
                config.ssh_addr,
                &config.tls_cert_path,
                &config.tls_key_path,
                &config.tls_client_ca_path,
                &config.ssh_host_key_path,
                &config.ssh_authorized_keys_path,
                &config.ssh_username,
                &config.ssh_subsystem_name,
                config.ssh_auth_mode,
                config.ssh_password.as_deref(),
            )
        }
        (true, false) => {
            info!("starting TCP and TLS RTR servers");
            service.spawn_tcp_and_tls_from_pem(
                config.tcp_addr,
                config.tls_addr,
                &config.tls_cert_path,
                &config.tls_key_path,
                &config.tls_client_ca_path,
            )
        }
        (false, true) => {
            info!("starting TCP and SSH RTR servers");
            service.spawn_tcp_and_ssh_from_openssh(
                config.tcp_addr,
                config.ssh_addr,
                &config.ssh_host_key_path,
                &config.ssh_authorized_keys_path,
                &config.ssh_username,
                &config.ssh_subsystem_name,
                config.ssh_auth_mode,
                config.ssh_password.as_deref(),
            )
        }
        (false, false) => {
            info!("starting TCP RTR server");
            service.spawn_tcp_only(config.tcp_addr)
        }
    }
}
fn spawn_refresh_task(
config: &AppConfig,
shared_cache: SharedRtrCache,
store: RtrStore,
notifier: RtrNotifier,
) -> JoinHandle<()> {
let refresh_interval = config.source_refresh_interval;
let payload_load_config = PayloadLoadConfig {
ccr_dir: config.ccr_dir.clone(),
slurm_dir: config.slurm_dir.clone(),
strict_ccr_validation: config.strict_ccr_validation,
};
tokio::spawn(async move {
let mut interval = tokio::time::interval(refresh_interval);
loop {
interval.tick().await;
let source_to_delta_started = Instant::now();
match load_payloads_from_latest_sources(&payload_load_config) {
Ok(payloads) => {
let (payload_count, updated) = {
let payload_count = payloads.len();
let source_snapshot = Snapshot::from_payloads(payloads);
let old_cache = shared_cache.load_full();
let old_serial = old_cache.serial_for_version(2);
let mut next_cache = old_cache.as_ref().clone();
let updated = match next_cache.update_with_snapshot(source_snapshot, &store) {
Ok(()) => {
let new_serial = next_cache.serial_for_version(2);
shared_cache.store(std::sync::Arc::new(next_cache));
if new_serial != old_serial {
info!(
"RTR cache refresh applied: ccr_dir={}, payload_count={}, old_serial={}, new_serial={}",
payload_load_config.ccr_dir,
payload_count,
old_serial,
new_serial
);
true
} else {
info!(
"RTR cache refresh found no change: ccr_dir={}, payload_count={}, serial={}",
payload_load_config.ccr_dir, payload_count, old_serial
);
false
}
}
Err(err) => {
warn!("RTR cache update failed: {:?}", err);
false
}
};
(payload_count, updated)
};
info!(
"RTR source-to-delta timing: phase=refresh_cache_update_complete, ccr_dir={}, payload_count={}, changed={}, elapsed_ms={}",
payload_load_config.ccr_dir,
payload_count,
updated,
source_to_delta_started.elapsed().as_millis()
);
if updated {
notifier.notify_cache_updated();
info!("RTR cache updated, notify signal emitted (session may skip SerialNotify due to rate limit)");
}
}
Err(err) => {
warn!(
"failed to reload CCR/SLURM payloads from {}: {:?} (source_to_delta_elapsed_ms={})",
payload_load_config.ccr_dir,
err,
source_to_delta_started.elapsed().as_millis()
);
}
}
}
})
}
/// Waits until the process is asked to stop.
///
/// Completes on Ctrl-C (SIGINT) on every platform. On Unix it also completes on SIGTERM,
/// which is what `kill`, systemd, Docker and Kubernetes send. Without that handler,
/// SIGTERM would terminate the process immediately and skip the graceful listener
/// shutdown in `main`.
///
/// # Errors
/// Returns an error if a signal handler cannot be installed or the Ctrl-C listener fails.
async fn wait_for_shutdown() -> Result<()> {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};

        let mut sigterm = signal(SignalKind::terminate())?;
        let ctrl_c_result = tokio::select! {
            result = tokio::signal::ctrl_c() => result,
            _ = sigterm.recv() => Ok(()),
        };
        ctrl_c_result?;
    }
    #[cfg(not(unix))]
    {
        tokio::signal::ctrl_c().await?;
    }
    info!("shutdown signal received");
    Ok(())
}
/// Logs the effective configuration at `info` level, one `key=value` line per setting.
///
/// TLS- and SSH-specific settings are only logged when that listener is enabled. The SSH
/// password itself is never logged, only whether one is configured. Every other
/// `AppConfig` field is logged, including `prune_delta_by_snapshot_size`.
fn log_startup_config(config: &AppConfig) {
    info!("starting RTR service");
    info!("db_path={}", config.db_path);
    info!("tcp_addr={}", config.tcp_addr);
    info!("tls_enabled={}", config.enable_tls);
    info!("ssh_enabled={}", config.enable_ssh);
    if config.enable_tls {
        info!("tls_addr={}", config.tls_addr);
        info!("tls_cert_path={}", config.tls_cert_path);
        info!("tls_key_path={}", config.tls_key_path);
        info!("tls_client_ca_path={}", config.tls_client_ca_path);
    }
    if config.enable_ssh {
        info!("ssh_addr={}", config.ssh_addr);
        info!("ssh_host_key_path={}", config.ssh_host_key_path);
        info!(
            "ssh_authorized_keys_path={}",
            config.ssh_authorized_keys_path
        );
        info!("ssh_username={}", config.ssh_username);
        info!("ssh_subsystem_name={}", config.ssh_subsystem_name);
        info!("ssh_auth_mode={}", config.ssh_auth_mode.as_str());
        info!("ssh_password_enabled={}", config.ssh_password.is_some());
    }
    info!("ccr_dir={}", config.ccr_dir);
    info!(
        "slurm_dir={}",
        config.slurm_dir.as_deref().unwrap_or("disabled")
    );
    info!("max_delta={}", config.max_delta);
    info!(
        "prune_delta_by_snapshot_size={}",
        config.prune_delta_by_snapshot_size
    );
    info!("strict_ccr_validation={}", config.strict_ccr_validation);
    info!(
        "source_refresh_interval_secs={}",
        config.source_refresh_interval.as_secs()
    );
    info!("rtr_timing_refresh_secs={}", config.timing.refresh);
    info!("rtr_timing_retry_secs={}", config.timing.retry);
    info!("rtr_timing_expire_secs={}", config.timing.expire);

    let service = &config.service_config;
    info!("max_connections={}", service.max_connections);
    info!(
        "max_concurrent_handshakes={}",
        service.max_concurrent_handshakes
    );
    info!("notify_queue_size={}", service.notify_queue_size);
    info!(
        "tcp_keepalive_secs={}",
        service
            .tcp_keepalive
            .map(|duration| duration.as_secs().to_string())
            .unwrap_or_else(|| "disabled".to_string())
    );
    info!("warn_insecure_tcp={}", service.warn_insecure_tcp);
    info!(
        "require_tls_server_dns_name_san={}",
        service.require_tls_server_dns_name_san
    );
    info!(
        "enforce_tls_client_san_ip_match={}",
        service.enforce_tls_client_san_ip_match
    );
}
/// Installs the global `tracing` subscriber.
///
/// The filter comes from `RUST_LOG` (`EnvFilter::try_from_default_env`). It falls back
/// to `warn` when that variable is unset or invalid, so the `info!` logs in this binary
/// stay hidden unless `RUST_LOG` enables them. Timestamps are rendered at the fixed
/// UTC+08:00 offset. If a global subscriber is already installed, the failure is printed
/// to stderr and startup continues.
fn init_tracing() {
    /// Timestamp formatter that renders the current time at a fixed +08:00 offset.
    struct ShanghaiTimer;

    impl tracing_subscriber::fmt::time::FormatTime for ShanghaiTimer {
        fn format_time(
            &self,
            w: &mut tracing_subscriber::fmt::format::Writer<'_>,
        ) -> std::fmt::Result {
            const OFFSET_SECS: i32 = 8 * 60 * 60;
            let offset = FixedOffset::east_opt(OFFSET_SECS)
                .expect("fixed +08:00 offset should always be valid");
            let local_now = Utc::now().with_timezone(&offset);
            write!(w, "{}", local_now.format("%Y-%m-%d %H:%M:%S%.3f %:z"))
        }
    }

    let env_filter = match tracing_subscriber::EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => tracing_subscriber::EnvFilter::new("warn"),
    };

    let installed = tracing_subscriber::fmt()
        .with_timer(ShanghaiTimer)
        .with_env_filter(env_filter)
        .with_target(true)
        .with_thread_ids(true)
        .with_level(true)
        .try_init();
    if let Err(err) = installed {
        eprintln!("failed to initialize tracing subscriber: {err}");
    }
}
/// Reads the environment variable `name`.
///
/// Returns `Ok(Some(value))` when the variable is set and `Ok(None)` when it is absent.
///
/// # Errors
/// Returns an error when the variable is set but cannot be read as a `String` (it is not
/// valid Unicode).
fn env_var(name: &str) -> Result<Option<String>> {
    env::var(name).map(Some).or_else(|err| match err {
        env::VarError::NotPresent => Ok(None),
        other => Err(anyhow!("failed to read {}: {}", name, other)),
    })
}
/// Parses a boolean flag from an environment-variable value.
///
/// Matching ignores surrounding whitespace and ASCII case. Accepts `1/true/yes/on` as
/// `true` and `0/false/no/off` as `false`. `name` only appears in the error message.
///
/// # Errors
/// Returns an error for any other input.
fn parse_bool(value: &str, name: &str) -> Result<bool> {
    let normalized = value.trim().to_ascii_lowercase();
    if matches!(normalized.as_str(), "1" | "true" | "yes" | "on") {
        Ok(true)
    } else if matches!(normalized.as_str(), "0" | "false" | "no" | "off") {
        Ok(false)
    } else {
        Err(anyhow!("invalid {} '{}': expected boolean", name, value))
    }
}
/// Parses `value` as a `u64` that must be at least 1. `name` only appears in errors.
///
/// # Errors
/// Returns an error if `value` is not a valid `u64` or is zero.
fn parse_positive_u64(value: &str, name: &str) -> Result<u64> {
    parse_positive_number(value, name)
}

/// Parses `value` as a `u32` that must be at least 1. `name` only appears in errors.
///
/// # Errors
/// Returns an error if `value` is not a valid `u32` or is zero.
fn parse_positive_u32(value: &str, name: &str) -> Result<u32> {
    parse_positive_number(value, name)
}

/// Shared implementation of the `parse_positive_*` helpers.
///
/// Intended for unsigned integer types, whose `Default` value is zero. No trimming is
/// done. Error messages quote the original, unparsed `value`.
fn parse_positive_number<T>(value: &str, name: &str) -> Result<T>
where
    T: std::str::FromStr + Default + PartialEq,
    T::Err: std::fmt::Display,
{
    let parsed = value
        .parse::<T>()
        .map_err(|err| anyhow!("invalid {} '{}': {}", name, value, err))?;
    if parsed == T::default() {
        return Err(anyhow!("invalid {} '{}': must be >= 1", name, value));
    }
    Ok(parsed)
}