//! Admin HTTP endpoint for the RTR server.
//!
//! A small hand-rolled HTTP/1.1 server that handles exactly one request per
//! connection (every response carries `connection: close`). It lets operators
//! read and patch the runtime config, trigger source reloads, manage SLURM
//! files and tail the process log files.
//!
//! NOTE(review): generic type arguments appear to have been stripped from this
//! copy of the file (e.g. `Arc>`, `mpsc::Sender,`, `-> Result {`,
//! `serde_json::from_slice::(`), so it does not compile as-is. TODO: restore
//! them from version control; several (the SLURM operation type, the file-list
//! element type, `payload_count`) cannot be recovered from this file alone.

use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::Duration;

use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::sleep;
use tracing::{info, warn};

use crate::rtr::config::{RuntimeConfig, RuntimeConfigPatch};
use crate::slurm::admin::{
    SlurmAdmin, SlurmFileActionRequest, SlurmFileOperationResult, SlurmFileWriteRequest,
    parse_reload_query,
};

/// Shared, cloneable handle to the live [`RuntimeConfig`].
///
/// Snapshots are read with [`RuntimeConfigHandle::current`]; every successful
/// [`RuntimeConfigHandle::apply_patch`] is also published on a `watch` channel
/// (see [`RuntimeConfigHandle::subscribe`]).
#[derive(Clone)]
pub struct RuntimeConfigHandle {
    current: Arc>,
    tx: watch::Sender,
}

/// Client for the source-reload task: sends a [`SourceReloadCommand`] over an
/// `mpsc` channel and awaits the reply on a `oneshot` channel.
#[derive(Clone)]
pub struct SourceReloadHandle {
    tx: mpsc::Sender,
}

impl SourceReloadHandle {
    /// Wraps the sending half of the reload-command channel.
    pub fn new(tx: mpsc::Sender) -> Self {
        Self { tx }
    }

    /// Asks the reload task to reload its sources and waits for the outcome.
    ///
    /// `phase` names the caller (e.g. `"admin_slurm_reload"`) and `force` is
    /// forwarded unchanged; both are interpreted by the reload task.
    ///
    /// # Errors
    ///
    /// Fails when the command channel is closed, when the task drops the reply
    /// sender, or when the task replies with an error.
    pub async fn reload(&self, phase: &'static str, force: bool) -> Result {
        let (respond_to, response) = oneshot::channel();
        self.tx
            .send(SourceReloadCommand {
                phase,
                force,
                respond_to,
            })
            .await
            .map_err(|_| anyhow!("source reload task is not available"))?;
        response
            .await
            .map_err(|_| anyhow!("source reload task dropped the response"))?
            // The task's own error value is converted into an `anyhow::Error`.
            .map_err(anyhow::Error::msg)
    }
}

/// Request sent to the source-reload task; the outcome goes back on `respond_to`.
pub struct SourceReloadCommand {
    pub phase: &'static str,
    pub force: bool,
    pub respond_to: oneshot::Sender>,
}

/// Outcome of a source reload, serialized into the admin API responses.
#[derive(Debug, Clone, Serialize)]
pub struct SourceReloadResult {
    pub phase: &'static str,
    pub changed: bool,
    pub skipped_unchanged: bool,
    pub payload_count: Option,
    pub serials: [u32; 3],
}

/// Dependencies shared by all admin request handlers; cloned once per connection.
#[derive(Clone)]
pub struct AdminState {
    runtime_config: RuntimeConfigHandle,
    /// `None` makes every reload attempt fail with "source reload is not available".
    source_reload: Option,
    /// `None` makes the SLURM file endpoints answer 400 "SLURM admin is disabled".
    slurm_admin: Option,
    log_tail: LogTailConfig,
}

impl AdminState {
    /// Bundles the handler dependencies.
    pub fn new(
        runtime_config: RuntimeConfigHandle,
        source_reload: Option,
        slurm_admin: Option,
        log_tail: LogTailConfig,
    ) -> Self {
        Self {
            runtime_config,
            source_reload,
            slurm_admin,
            log_tail,
        }
    }
}

/// Location of the log files served by `/admin/rtr/logs/tail`.
///
/// Files are named `{dir}/{name}.{stream}.log`, where `{stream}` is `stdout`
/// or `stderr`.
#[derive(Clone)]
pub struct LogTailConfig {
    dir: PathBuf,
    name: String,
}

impl LogTailConfig {
    /// Reads the log location from the environment.
    ///
    /// * directory: `RPKI_RTR_LOG_DIR`, defaulting to `/app/logs`
    /// * base name: `RPKI_RTR_LOG_NAME`, else `HOSTNAME`, else `rpki-rtr`
    pub fn from_env() -> Self {
        let dir = std::env::var_os("RPKI_RTR_LOG_DIR")
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from("/app/logs"));
        let name = std::env::var("RPKI_RTR_LOG_NAME")
            .or_else(|_| std::env::var("HOSTNAME"))
            .unwrap_or_else(|_| "rpki-rtr".to_string());
        Self { dir, name }
    }

    /// Full path of the log file for `stream`.
    fn path_for(&self, stream: LogStream) -> PathBuf {
        self.dir
            .join(format!("{}.{}.log", self.name, stream.as_str()))
    }
}

impl RuntimeConfigHandle {
    /// Creates a handle holding `config` as the initial value.
    pub fn new(config: RuntimeConfig) -> Self {
        // The initial receiver is discarded; consumers call `subscribe()`.
        let (tx, _) = watch::channel(config.clone());
        Self {
            current: Arc::new(RwLock::new(config)),
            tx,
        }
    }

    /// Returns a clone of the current config.
    ///
    /// A poisoned lock is tolerated: the value is only ever replaced wholesale,
    /// so it cannot be observed half-updated.
    pub fn current(&self) -> RuntimeConfig {
        self.current
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Returns a receiver that observes every config stored by `apply_patch`.
    pub fn subscribe(&self) -> watch::Receiver {
        self.tx.subscribe()
    }

    /// Applies `patch` to the current config, stores the result and publishes
    /// it to subscribers. Returns the new config.
    ///
    /// # Errors
    ///
    /// Returns the error from `RuntimeConfig::apply_patch`; nothing is stored
    /// or published in that case.
    pub fn apply_patch(&self, patch: RuntimeConfigPatch) -> Result {
        // Hold the write lock across read-modify-write-publish. Otherwise two
        // concurrent patches could both start from the same base (one is lost),
        // and the watch channel could end up holding a different value than
        // `current`.
        let mut current = self
            .current
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let next = (*current).clone().apply_patch(patch)?;
        *current = next.clone();
        let _ = self.tx.send_replace(next.clone());
        Ok(next)
    }
}

/// Spawns the admin HTTP server on the current Tokio runtime.
///
/// Startup errors (missing token on a non-loopback address, bind failure) are
/// logged with `warn!` instead of being returned.
pub fn spawn_admin_config_server(
    addr: SocketAddr,
    token: Option,
    state: AdminState,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        if let Err(err) = run_admin_config_server(addr, token, state).await {
            warn!("RTR admin config server exited: {:?}", err);
        }
    })
}

/// Binds `addr` and serves every accepted connection on its own task.
///
/// Refuses to start without a token unless `addr` is a loopback address.
/// Returns only on startup errors; failed `accept` calls are logged and retried.
async fn run_admin_config_server(
    addr: SocketAddr,
    token: Option,
    state: AdminState,
) -> Result<()> {
    if token.is_none() && !addr.ip().is_loopback() {
        return Err(anyhow!(
            "RPKI_RTR_ADMIN_TOKEN is required when admin addr is not loopback: {}",
            addr
        ));
    }
    let listener = TcpListener::bind(addr)
        .await
        .with_context(|| format!("bind RTR admin config server on {addr}"))?;
    info!("RTR admin config server listening on {}", addr);
    loop {
        let (stream, peer_addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            Err(err) => {
                // Accept errors can be transient (e.g. running out of file
                // descriptors); back off briefly instead of shutting the admin
                // server down for the lifetime of the process.
                warn!("RTR admin config server accept failed: {:?}", err);
                sleep(Duration::from_millis(100)).await;
                continue;
            }
        };
        let token = token.clone();
        let state = state.clone();
        tokio::spawn(async move {
            if let Err(err) = handle_admin_connection(stream, peer_addr, token, state).await {
                warn!(
                    "RTR admin config request failed from {}: {:?}",
                    peer_addr, err
                );
            }
        });
    }
}

/// Reads one HTTP request from `stream`, authenticates it and dispatches it.
///
/// * The header block must fit into 64 KiB, otherwise 413.
/// * With a configured token the request needs `authorization: Bearer {token}`,
///   otherwise 401.
/// * The body is capped at 32 KiB, or at `SlurmAdmin::max_body_bytes` for
///   paths under `/admin/rtr/slurm/` when SLURM admin is enabled; larger
///   `content-length` values get 413.
async fn handle_admin_connection(
    mut stream: TcpStream,
    peer_addr: SocketAddr,
    token: Option,
    state: AdminState,
) -> Result<()> {
    // NOTE(review): there is no read timeout, so a client that never finishes
    // its headers or body keeps this task and socket alive indefinitely.
    let mut buffer = vec![0u8; 64 * 1024];
    let mut read = 0usize;
    // Read until the blank line that terminates the header block.
    let header_end = loop {
        if read == buffer.len() {
            write_response(&mut stream, 413, "payload too large", "text/plain").await?;
            return Ok(());
        }
        let n = stream.read(&mut buffer[read..]).await?;
        if n == 0 {
            // Peer closed before sending a complete header: nothing to answer.
            return Ok(());
        }
        read += n;
        if let Some(pos) = find_header_end(&buffer[..read]) {
            break pos;
        }
    };

    let header = std::str::from_utf8(&buffer[..header_end])
        .map_err(|err| anyhow!("invalid HTTP header from {}: {}", peer_addr, err))?;
    let request = parse_request_header(header)?;

    if let Some(token) = token.as_deref() {
        // NOTE(review): plain string comparison, not constant-time.
        let expected = format!("Bearer {token}");
        if request.authorization.as_deref() != Some(expected.as_str()) {
            write_response(&mut stream, 401, "unauthorized", "text/plain").await?;
            return Ok(());
        }
    }

    let content_length = request.content_length.unwrap_or(0);
    // Requests under /admin/rtr/slurm/ use the SLURM admin's own limit when enabled.
    let max_body_bytes = if request.path.starts_with("/admin/rtr/slurm/") {
        state
            .slurm_admin
            .as_ref()
            .map(SlurmAdmin::max_body_bytes)
            .unwrap_or(32 * 1024)
    } else {
        32 * 1024
    };
    if content_length > max_body_bytes {
        write_response(&mut stream, 413, "payload too large", "text/plain").await?;
        return Ok(());
    }

    // Part of the body may already be in `buffer`, right after `\r\n\r\n`.
    let body_start = header_end + 4;
    let available_body = read.saturating_sub(body_start);
    let mut body = Vec::with_capacity(content_length);
    body.extend_from_slice(&buffer[body_start..read]);
    if available_body < content_length {
        let remaining = content_length - available_body;
        let mut tail = vec![0u8; remaining];
        stream.read_exact(&mut tail).await?;
        body.extend_from_slice(&tail);
    }
    // Discard any bytes received beyond the declared length.
    body.truncate(content_length);

    route_admin_request(&mut stream, request, body, state, peer_addr).await
}

/// Dispatches a parsed request to its handler.
///
/// * `POST /admin/rtr/config`: apply a JSON runtime-config patch
/// * `GET /admin/rtr/health`: report which admin features are enabled
/// * `GET /admin/rtr/logs/tail`: stream a log file (see `tail_log_stream`)
/// * `GET /admin/rtr/config`: return the current config
/// * `POST /admin/rtr/slurm/reload`: forced source reload
/// * `GET /admin/rtr/slurm/files`: list SLURM files
/// * `POST /admin/rtr/slurm/files`: create or update a named SLURM file
/// * `/admin/rtr/slurm/files/...`: per-file operations (see `route_slurm_file_request`)
///
/// Anything else gets 404.
async fn route_admin_request(
    stream: &mut TcpStream,
    request: RequestHeader,
    body: Vec,
    state: AdminState,
    peer_addr: SocketAddr,
) -> Result<()> {
    if request.method == "POST" && request.path == "/admin/rtr/config" {
        let patch = match serde_json::from_slice::(&body) {
            Ok(patch) => patch,
            Err(err) => {
                let message = format!("invalid json: {err}");
                write_response(stream, 400, &message, "text/plain").await?;
                return Ok(());
            }
        };
        match state.runtime_config.apply_patch(patch) {
            Ok(config) => {
                info!("RTR admin config updated from {}: {:?}", peer_addr, config);
                let json = serde_json::to_string_pretty(&AdminConfigResponse {
                    status: "ok",
                    config,
                })?;
                write_response(stream, 200, &json, "application/json").await?;
            }
            Err(err) => {
                let message = format!("invalid config: {err}");
                write_response(stream, 400, &message, "text/plain").await?;
            }
        }
        return Ok(());
    }

    if request.method == "GET" && request.path == "/admin/rtr/health" {
        let json = serde_json::to_string_pretty(&AdminHealthResponse {
            status: "ok",
            config_api: true,
            source_reload_api: state.source_reload.is_some(),
            slurm_api: state.slurm_admin.is_some(),
            logs_api: true,
        })?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    if request.method == "GET" && request.path == "/admin/rtr/logs/tail" {
        return tail_log_stream(stream, request.query.as_deref(), &state.log_tail).await;
    }

    if request.method == "GET" && request.path == "/admin/rtr/config" {
        let json = serde_json::to_string_pretty(&AdminConfigResponse {
            status: "ok",
            config: state.runtime_config.current(),
        })?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    if request.method == "POST" && request.path == "/admin/rtr/slurm/reload" {
        let reload = reload_source(&state, "admin_slurm_reload", true).await;
        return write_json_or_error(stream, reload).await;
    }

    if request.path == "/admin/rtr/slurm/files" && request.method == "GET" {
        let Some(slurm_admin) = state.slurm_admin.as_ref() else {
            write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
            return Ok(());
        };
        let json = serde_json::to_string_pretty(&SlurmFileListResponse {
            status: "ok",
            files: slurm_admin.list_files()?,
        })?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    if request.path == "/admin/rtr/slurm/files" && request.method == "POST" {
        let Some(slurm_admin) = state.slurm_admin.as_ref() else {
            write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
            return Ok(());
        };
        // The body must name the file; `reload` defaults to false here.
        let request_body = parse_json::(stream, &body).await?;
        let Some(name) = request_body.name.as_deref() else {
            write_response(stream, 400, "missing SLURM file name", "text/plain").await?;
            return Ok(());
        };
        let reload = request_body.reload.unwrap_or(false);
        let operation = slurm_admin.put_file(name, &request_body.content, "create_or_update");
        return apply_slurm_operation(stream, &state, operation, reload).await;
    }

    let slurm_file_prefix = "/admin/rtr/slurm/files/";
    if request.path.starts_with(slurm_file_prefix) {
        let rest = request.path[slurm_file_prefix.len()..].to_string();
        return route_slurm_file_request(stream, request, body, state, rest).await;
    }

    write_response(stream, 404, "not found", "text/plain").await?;
    Ok(())
}

/// Handles `/admin/rtr/slurm/files/{rest}`.
///
/// * `GET`: read file `rest`
/// * `PUT`: create or update `rest` from a JSON body with `content`
/// * `DELETE`: delete `rest`; the JSON body is optional
/// * `POST {name}/enable` and `POST {name}/disable`: toggle a file; the JSON
///   body is optional
///
/// For the mutating methods, whether to reload afterwards is decided by
/// `parse_reload_query` from the query string and the body's `reload` field.
async fn route_slurm_file_request(
    stream: &mut TcpStream,
    request: RequestHeader,
    body: Vec,
    state: AdminState,
    rest: String,
) -> Result<()> {
    let Some(slurm_admin) = state.slurm_admin.as_ref() else {
        write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
        return Ok(());
    };

    if request.method == "GET" {
        let file = match slurm_admin.read_file(&rest) {
            Ok(file) => file,
            Err(err) => {
                let message = format!("invalid SLURM file request: {err}");
                write_response(stream, 400, &message, "text/plain").await?;
                return Ok(());
            }
        };
        let json = serde_json::to_string_pretty(&SlurmFileContentResponse { status: "ok", file })?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    if request.method == "PUT" {
        let request_body = parse_json::(stream, &body).await?;
        let reload = parse_reload_query(request.query.as_deref(), request_body.reload);
        let operation = slurm_admin.put_file(&rest, &request_body.content, "create_or_update");
        return apply_slurm_operation(stream, &state, operation, reload).await;
    }

    if request.method == "DELETE" {
        let reload = if body.is_empty() {
            parse_reload_query(request.query.as_deref(), None)
        } else {
            let request_body = parse_json::(stream, &body).await?;
            parse_reload_query(request.query.as_deref(), request_body.reload)
        };
        let operation = slurm_admin.delete_file(&rest);
        return apply_slurm_operation(stream, &state, operation, reload).await;
    }

    if request.method == "POST" {
        if let Some(name) = rest.strip_suffix("/enable") {
            let request_body: SlurmFileActionRequest = parse_optional_json(stream, &body).await?;
            let reload = parse_reload_query(request.query.as_deref(), request_body.reload);
            let operation = slurm_admin.enable_file(name);
            return apply_slurm_operation(stream, &state, operation, reload).await;
        }
        if let Some(name) = rest.strip_suffix("/disable") {
            let request_body: SlurmFileActionRequest = parse_optional_json(stream, &body).await?;
            let reload = parse_reload_query(request.query.as_deref(), request_body.reload);
            let operation = slurm_admin.disable_file(name);
            return apply_slurm_operation(stream, &state, operation, reload).await;
        }
    }

    write_response(stream, 404, "not found", "text/plain").await?;
    Ok(())
}

/// Reports the outcome of a SLURM file operation, optionally reloading sources.
///
/// * operation error: 400 `text/plain`
/// * `reload == false`: 200 with the operation result
/// * reload succeeds: 200 with the operation result and the reload result
/// * reload fails: the operation is rolled back, a second reload is attempted,
///   and a 400 JSON report describes both the failure and the rollback
async fn apply_slurm_operation(
    stream: &mut TcpStream,
    state: &AdminState,
    operation: Result,
    reload: bool,
) -> Result<()> {
    let operation = match operation {
        Ok(operation) => operation,
        Err(err) => {
            let message = format!("invalid SLURM operation: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
            return Ok(());
        }
    };

    if !reload {
        let json = serde_json::to_string_pretty(&SlurmOperationResponse {
            status: "ok",
            operation: operation.result,
            reload: None,
            rollback: None,
        })?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    match reload_source(state, "admin_slurm_changed", true).await {
        Ok(result) => {
            let json = serde_json::to_string_pretty(&SlurmOperationResponse {
                status: "ok",
                operation: operation.result,
                reload: Some(result),
                rollback: None,
            })?;
            write_response(stream, 200, &json, "application/json").await?;
        }
        Err(err) => {
            // Undo the file change, then reload again with the restored files.
            let rollback = match operation.rollback() {
                Ok(()) => match reload_source(state, "admin_slurm_rollback", true).await {
                    Ok(result) => Some(RollbackReport {
                        status: "ok".to_string(),
                        reload: Some(result),
                        error: None,
                    }),
                    Err(reload_err) => Some(RollbackReport {
                        status: "reload_failed".to_string(),
                        reload: None,
                        error: Some(reload_err.to_string()),
                    }),
                },
                Err(rollback_err) => Some(RollbackReport {
                    status: "failed".to_string(),
                    reload: None,
                    error: Some(rollback_err.to_string()),
                }),
            };
            let json = serde_json::to_string_pretty(&SlurmOperationErrorResponse {
                status: "reload_failed",
                error: err.to_string(),
                rollback,
            })?;
            write_response(stream, 400, &json, "application/json").await?;
        }
    }
    Ok(())
}

/// Runs a source reload through `state.source_reload`.
///
/// # Errors
///
/// Fails with "source reload is not available" when no reload handle is
/// configured, or with whatever `SourceReloadHandle::reload` returns.
async fn reload_source(
    state: &AdminState,
    phase: &'static str,
    force: bool,
) -> Result {
    let Some(source_reload) = state.source_reload.as_ref() else {
        return Err(anyhow!("source reload is not available"));
    };
    source_reload.reload(phase, force).await
}

/// Deserializes `body` as JSON.
///
/// On failure a 400 `invalid json: ...` response has already been written when
/// this returns `Err`; the error still propagates to the connection task, which
/// logs it.
async fn parse_json Deserialize<'de>>(
    stream: &mut TcpStream,
    body: &[u8],
) -> Result {
    match serde_json::from_slice::(body) {
        Ok(value) => Ok(value),
        Err(err) => {
            let message = format!("invalid json: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
            Err(anyhow!(message))
        }
    }
}

/// Like `parse_json`, but an empty body yields `T::default()`.
async fn parse_optional_json Deserialize<'de> + Default>(
    stream: &mut TcpStream,
    body: &[u8],
) -> Result {
    if body.is_empty() {
        return Ok(T::default());
    }
    parse_json(stream, body).await
}

/// Writes `Ok` values as 200 pretty-printed JSON and `Err` as 400 `reload failed: ...`.
async fn write_json_or_error(
    stream: &mut TcpStream,
    result: Result,
) -> Result<()> {
    match result {
        Ok(value) => {
            let json = serde_json::to_string_pretty(&value)?;
            write_response(stream, 200, &json, "application/json").await?;
        }
        Err(err) => {
            let message = format!("reload failed: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
        }
    }
    Ok(())
}

/// Which process log file to tail.
#[derive(Clone, Copy)]
enum LogStream {
    Stdout,
    Stderr,
}

impl LogStream {
    /// Parses the `stream` query value; a missing value means `stdout`.
    fn parse(value: Option<&str>) -> Result {
        match value.unwrap_or("stdout") {
            "stdout" => Ok(Self::Stdout),
            "stderr" => Ok(Self::Stderr),
            other => Err(anyhow!(
                "invalid stream '{}': expected 'stdout' or 'stderr'",
                other
            )),
        }
    }

    /// Name used in the log file name.
    fn as_str(self) -> &'static str {
        match self {
            Self::Stdout => "stdout",
            Self::Stderr => "stderr",
        }
    }
}

/// Parsed query of `GET /admin/rtr/logs/tail`.
struct LogTailRequest {
    stream: LogStream,
    /// Number of initial lines to send; defaults to 200, clamped to `1..=5000`.
    lines: usize,
    /// Keep streaming appended data after the initial tail; defaults to true.
    follow: bool,
}

impl LogTailRequest {
    /// Parses `stream`, `lines` and `follow` from the raw query string.
    fn parse(query: Option<&str>) -> Result {
        let stream = LogStream::parse(query_param(query, "stream"))?;
        let lines = query_param(query, "lines")
            .map(|value| {
                value
                    .parse::()
                    .map_err(|err| anyhow!("invalid lines '{}': {}", value, err))
            })
            .transpose()?
            .unwrap_or(200)
            .clamp(1, 5000);
        let follow = query_param(query, "follow")
            .map(|value| parse_bool_query(value, "follow"))
            .transpose()?
            .unwrap_or(true);
        Ok(Self {
            stream,
            lines,
            follow,
        })
    }
}

/// Streams a log file as a chunked `text/plain` response.
///
/// Sends the last `lines` lines first. With `follow` it then polls the file
/// once a second and forwards appended bytes until the file becomes
/// unavailable or a write to the client fails. A shrinking file is treated as
/// truncated and re-streamed from the beginning.
///
/// NOTE(review): only the size is compared, so a replacement file at least as
/// large as the old offset is not detected; and an idle log means a
/// disconnected client is only noticed on the next write.
async fn tail_log_stream(
    stream: &mut TcpStream,
    query: Option<&str>,
    config: &LogTailConfig,
) -> Result<()> {
    let request = match LogTailRequest::parse(query) {
        Ok(request) => request,
        Err(err) => {
            let message = format!("invalid log tail request: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
            return Ok(());
        }
    };

    let path = config.path_for(request.stream);
    if !Path::new(&path).is_file() {
        let message = format!("log file not found: {}", path.display());
        write_response(stream, 404, &message, "text/plain").await?;
        return Ok(());
    }

    write_chunked_headers(stream, "text/plain; charset=utf-8").await?;
    let (tail, mut offset) = read_tail(&path, request.lines).await?;
    if !tail.is_empty() {
        write_chunk(stream, &tail).await?;
    }
    if !request.follow {
        write_final_chunk(stream).await?;
        return Ok(());
    }

    loop {
        sleep(Duration::from_secs(1)).await;
        let metadata = match tokio::fs::metadata(&path).await {
            Ok(metadata) => metadata,
            Err(err) => {
                // Best effort: the client may already be gone.
                let message = format!("\nlog file unavailable: {err}\n");
                let _ = write_chunk(stream, message.as_bytes()).await;
                let _ = write_final_chunk(stream).await;
                return Ok(());
            }
        };
        let len = metadata.len();
        if len < offset {
            offset = 0;
            write_chunk(stream, b"\nlog file truncated; restarting from beginning\n").await?;
        }
        if len == offset {
            continue;
        }
        let mut file = tokio::fs::File::open(&path).await?;
        file.seek(std::io::SeekFrom::Start(offset)).await?;
        let mut buf = Vec::new();
        file.read_to_end(&mut buf).await?;
        // Advance by what was actually read: the file may have grown after
        // `metadata` was taken, and jumping to `len` would resend those bytes.
        offset += buf.len() as u64;
        if !buf.is_empty() {
            write_chunk(stream, &buf).await?;
        }
    }
}

/// Reads at most the last 1 MiB of `path` and returns its last `lines` lines,
/// together with the file offset where reading stopped (the starting point for
/// follow mode).
///
/// NOTE(review): when the 1 MiB window starts mid-line and holds no more than
/// `lines` lines, the leading partial line is included.
async fn read_tail(path: &Path, lines: usize) -> Result<(Vec, u64)> {
    const MAX_INITIAL_TAIL_BYTES: u64 = 1024 * 1024;
    let metadata = tokio::fs::metadata(path).await?;
    let len = metadata.len();
    let start = len.saturating_sub(MAX_INITIAL_TAIL_BYTES);
    let mut file = tokio::fs::File::open(path).await?;
    file.seek(std::io::SeekFrom::Start(start)).await?;
    let mut buf = Vec::new();
    file.read_to_end(&mut buf).await?;
    // Report the position actually reached rather than `len`: bytes appended
    // after `metadata` was taken are already in `buf` and must not be resent.
    let end = start + buf.len() as u64;
    Ok((tail_log_lines(buf, lines), end))
}

/// Returns the last `lines` lines of `buf`, like `tail -n`.
///
/// A trailing `\n` terminates the final line rather than starting a new one,
/// and a final line without a trailing `\n` still counts as a line. When `buf`
/// holds `lines` lines or fewer it is returned unchanged; `lines == 0` or an
/// empty `buf` yields an empty vector.
pub fn tail_log_lines(buf: Vec, lines: usize) -> Vec {
    if lines == 0 || buf.is_empty() {
        return Vec::new();
    }
    // Ignore the terminator of the last line so that only separators between
    // lines are counted, whether or not the buffer ends with a newline.
    let search_end = if buf.ends_with(b"\n") {
        buf.len() - 1
    } else {
        buf.len()
    };
    let mut seen = 0usize;
    for (idx, byte) in buf[..search_end].iter().enumerate().rev() {
        if *byte == b'\n' {
            seen += 1;
            if seen == lines {
                return buf[idx + 1..].to_vec();
            }
        }
    }
    buf
}

/// Returns the raw (not percent-decoded) value of the first `key` parameter in
/// `query`; a bare `key` without `=` yields `""`.
fn query_param<'a>(query: Option<&'a str>, key: &str) -> Option<&'a str> {
    query?.split('&').find_map(|part| {
        let (name, value) = part.split_once('=').unwrap_or((part, ""));
        (name == key).then_some(value)
    })
}

/// Parses `true/1/yes/on` and `false/0/no/off`; `name` is used in the error.
fn parse_bool_query(value: &str, name: &str) -> Result {
    match value {
        "true" | "1" | "yes" | "on" => Ok(true),
        "false" | "0" | "no" | "off" => Ok(false),
        _ => Err(anyhow!("invalid {} '{}': expected true/false", name, value)),
    }
}

/// Writes a 200 response head for a chunked body; the connection closes afterwards.
async fn write_chunked_headers(stream: &mut TcpStream, content_type: &str) -> Result<()> {
    let response = format!(
        "HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ntransfer-encoding: chunked\r\ncache-control: no-store\r\nconnection: close\r\n\r\n"
    );
    stream.write_all(response.as_bytes()).await?;
    Ok(())
}

/// Writes one chunk and flushes it.
///
/// Empty input is skipped because a zero-length chunk would terminate the body.
async fn write_chunk(stream: &mut TcpStream, chunk: &[u8]) -> Result<()> {
    if chunk.is_empty() {
        return Ok(());
    }
    let header = format!("{:x}\r\n", chunk.len());
    stream.write_all(header.as_bytes()).await?;
    stream.write_all(chunk).await?;
    stream.write_all(b"\r\n").await?;
    stream.flush().await?;
    Ok(())
}

/// Writes the terminating zero-length chunk of a chunked body.
async fn write_final_chunk(stream: &mut TcpStream) -> Result<()> {
    stream.write_all(b"0\r\n\r\n").await?;
    Ok(())
}

/// Body of `GET /admin/rtr/health`.
#[derive(Serialize)]
struct AdminHealthResponse {
    status: &'static str,
    config_api: bool,
    source_reload_api: bool,
    slurm_api: bool,
    logs_api: bool,
}

/// Body of `GET`/`POST /admin/rtr/config`.
#[derive(Serialize)]
struct AdminConfigResponse {
    status: &'static str,
    config: RuntimeConfig,
}

/// Body of `GET /admin/rtr/slurm/files`.
#[derive(Serialize)]
struct SlurmFileListResponse {
    status: &'static str,
    files: Vec,
}

/// Body of `GET /admin/rtr/slurm/files/{name}`.
#[derive(Serialize)]
struct SlurmFileContentResponse {
    status: &'static str,
    file: crate::slurm::admin::SlurmFileContent,
}

/// Successful SLURM file operation, with the optional reload result.
#[derive(Serialize)]
struct SlurmOperationResponse {
    status: &'static str,
    operation: SlurmFileOperationResult,
    reload: Option,
    rollback: Option,
}

/// SLURM file operation whose follow-up reload failed.
#[derive(Serialize)]
struct SlurmOperationErrorResponse {
    status: &'static str,
    error: String,
    rollback: Option,
}

/// Outcome of rolling back a SLURM file operation: `ok`, `reload_failed` or `failed`.
#[derive(Serialize)]
struct RollbackReport {
    status: String,
    reload: Option,
    error: Option,
}

/// The parts of an HTTP request head that the router uses.
struct RequestHeader {
    method: String,
    path: String,
    /// Raw query string after `?`, if any.
    query: Option,
    content_length: Option,
    authorization: Option,
}

/// Parses the request line plus the `content-length` and `authorization`
/// headers (names matched case-insensitively; the last occurrence wins).
/// Other headers and lines without `:` are ignored.
///
/// # Errors
///
/// Fails on a missing request line, method or target, or a non-numeric
/// `content-length`.
fn parse_request_header(header: &str) -> Result {
    let mut lines = header.lines();
    let request_line = lines
        .next()
        .ok_or_else(|| anyhow!("missing HTTP request line"))?;
    let mut parts = request_line.split_whitespace();
    let method = parts
        .next()
        .ok_or_else(|| anyhow!("missing HTTP method"))?
        .to_string();
    let target = parts
        .next()
        .ok_or_else(|| anyhow!("missing HTTP path"))?
        .to_string();
    let (path, query) = match target.split_once('?') {
        Some((path, query)) => (path.to_string(), Some(query.to_string())),
        None => (target, None),
    };

    let mut content_length = None;
    let mut authorization = None;
    for line in lines {
        let Some((name, value)) = line.split_once(':') else {
            continue;
        };
        let name = name.trim().to_ascii_lowercase();
        let value = value.trim();
        match name.as_str() {
            "content-length" => {
                content_length = Some(
                    value
                        .parse::()
                        .map_err(|err| anyhow!("invalid content-length '{}': {}", value, err))?,
                );
            }
            "authorization" => authorization = Some(value.to_string()),
            _ => {}
        }
    }

    Ok(RequestHeader {
        method,
        path,
        query,
        content_length,
        authorization,
    })
}

/// Index of the `\r\n\r\n` that ends the header block, if present.
fn find_header_end(buffer: &[u8]) -> Option {
    buffer.windows(4).position(|window| window == b"\r\n\r\n")
}

/// Writes a complete response with a fixed body and `connection: close`.
///
/// Status codes other than 200/400/401/404/413 are labelled
/// `Internal Server Error`.
async fn write_response(
    stream: &mut TcpStream,
    status: u16,
    body: &str,
    content_type: &str,
) -> Result<()> {
    let reason = match status {
        200 => "OK",
        400 => "Bad Request",
        401 => "Unauthorized",
        404 => "Not Found",
        413 => "Payload Too Large",
        _ => "Internal Server Error",
    };
    // Kept on one line on purpose: a line break inside this literal would put a
    // raw newline into the `content-type` header line.
    let response = format!(
        "HTTP/1.1 {status} {reason}\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(response.as_bytes()).await?;
    Ok(())
}