rpki/src/rtr/admin.rs
2026-06-23 17:04:00 +08:00

858 lines
26 KiB
Rust

use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::sleep;
use tracing::{info, warn};
use crate::rtr::config::{RuntimeConfig, RuntimeConfigPatch};
use crate::slurm::admin::{
SlurmAdmin, SlurmFileActionRequest, SlurmFileOperationResult, SlurmFileWriteRequest,
parse_reload_query,
};
/// Cloneable handle to the live runtime configuration.
///
/// Keeps the latest `RuntimeConfig` behind a shared lock and owns the sending
/// side of a `watch` channel on which every applied configuration is published.
#[derive(Clone)]
pub struct RuntimeConfigHandle {
/// Latest applied configuration, shared by every clone of this handle.
current: Arc<RwLock<RuntimeConfig>>,
/// Watch-channel sender used to publish each newly applied configuration.
tx: watch::Sender<RuntimeConfig>,
}
/// Cloneable handle used to send reload commands to the source reload task.
#[derive(Clone)]
pub struct SourceReloadHandle {
/// Sending half of the command channel consumed by the reload task.
tx: mpsc::Sender<SourceReloadCommand>,
}
impl SourceReloadHandle {
pub fn new(tx: mpsc::Sender<SourceReloadCommand>) -> Self {
Self { tx }
}
pub async fn reload(&self, phase: &'static str, force: bool) -> Result<SourceReloadResult> {
let (respond_to, response) = oneshot::channel();
self.tx
.send(SourceReloadCommand {
phase,
force,
respond_to,
})
.await
.map_err(|_| anyhow!("source reload task is not available"))?;
response
.await
.map_err(|_| anyhow!("source reload task dropped the response"))?
.map_err(anyhow::Error::msg)
}
}
/// Command sent over the reload channel by `SourceReloadHandle::reload`.
pub struct SourceReloadCommand {
/// Static label identifying why the reload was requested (e.g. `"admin_slurm_reload"`).
pub phase: &'static str,
/// Force flag forwarded to the reload task; its exact effect is defined by that task.
pub force: bool,
/// One-shot channel on which the reload task reports success or an error message.
pub respond_to: oneshot::Sender<Result<SourceReloadResult, String>>,
}
/// Outcome of a source reload, serialized into admin API responses.
#[derive(Debug, Clone, Serialize)]
pub struct SourceReloadResult {
/// Phase label the reload was requested with.
pub phase: &'static str,
/// Presumably whether the reload changed the served data — confirm against the reload task.
pub changed: bool,
/// Presumably set when the reload was skipped because nothing changed — confirm against the reload task.
pub skipped_unchanged: bool,
/// Optional payload count reported by the reload task.
pub payload_count: Option<usize>,
/// Three serial numbers reported by the reload task.
/// NOTE(review): the meaning of each slot is not visible in this file — document at the producer.
pub serials: [u32; 3],
}
/// Shared state handed to every admin connection handler.
#[derive(Clone)]
pub struct AdminState {
/// Handle used to read and patch the runtime configuration.
runtime_config: RuntimeConfigHandle,
/// Reload handle; `None` makes reload requests fail with "source reload is not available".
source_reload: Option<SourceReloadHandle>,
/// SLURM file administration; `None` makes the SLURM file endpoints answer 400.
slurm_admin: Option<SlurmAdmin>,
/// Location of the log files served by the log tail endpoint.
log_tail: LogTailConfig,
}
impl AdminState {
pub fn new(
runtime_config: RuntimeConfigHandle,
source_reload: Option<SourceReloadHandle>,
slurm_admin: Option<SlurmAdmin>,
log_tail: LogTailConfig,
) -> Self {
Self {
runtime_config,
source_reload,
slurm_admin,
log_tail,
}
}
}
/// Describes where the log files served by the log tail endpoint live.
#[derive(Clone)]
pub struct LogTailConfig {
/// Directory that contains the log files.
dir: PathBuf,
/// Base name of the log files; the full name is `<name>.<stream>.log`.
name: String,
}
impl LogTailConfig {
pub fn from_env() -> Self {
let dir = std::env::var_os("RPKI_RTR_LOG_DIR")
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/app/logs"));
let name = std::env::var("RPKI_RTR_LOG_NAME")
.or_else(|_| std::env::var("HOSTNAME"))
.unwrap_or_else(|_| "rpki-rtr".to_string());
Self { dir, name }
}
fn path_for(&self, stream: LogStream) -> PathBuf {
self.dir
.join(format!("{}.{}.log", self.name, stream.as_str()))
}
}
impl RuntimeConfigHandle {
    /// Creates a handle whose initial state and initial watch value are `config`.
    pub fn new(config: RuntimeConfig) -> Self {
        // The initial receiver is dropped on purpose; consumers create their own
        // through `subscribe`, and `send_replace` works without any receiver.
        let (tx, _) = watch::channel(config.clone());
        Self {
            current: Arc::new(RwLock::new(config)),
            tx,
        }
    }

    /// Returns a copy of the latest applied configuration.
    ///
    /// A poisoned lock is recovered instead of panicking.
    pub fn current(&self) -> RuntimeConfig {
        self.current
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Returns a new receiver that observes every configuration published from now on.
    pub fn subscribe(&self) -> watch::Receiver<RuntimeConfig> {
        self.tx.subscribe()
    }

    /// Applies `patch` on top of the current configuration and publishes the result.
    ///
    /// The read-modify-write and the publish happen under a single write lock, so
    /// concurrent patches are applied one after another (none is lost) and
    /// subscribers receive the configurations in the order they were applied.
    ///
    /// Lock order is `current` then the watch channel: callers must not invoke
    /// methods of this handle while holding a `watch::Ref` borrowed from a
    /// receiver of the same channel.
    ///
    /// # Errors
    /// Returns the error produced by `RuntimeConfig::apply_patch`; in that case
    /// neither the stored configuration nor the watch value changes.
    pub fn apply_patch(&self, patch: RuntimeConfigPatch) -> Result<RuntimeConfig> {
        let mut current = self
            .current
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Patch a copy so a rejected patch leaves the stored value untouched.
        let next = current.clone().apply_patch(patch)?;
        *current = next.clone();
        // `send_replace` is synchronous and succeeds even with no receivers.
        let _ = self.tx.send_replace(next.clone());
        Ok(next)
    }
}
/// Spawns the admin HTTP server on the Tokio runtime and returns its task handle.
///
/// If the server stops with an error, the error is logged as a warning and the
/// task finishes.
pub fn spawn_admin_config_server(
    addr: SocketAddr,
    token: Option<String>,
    state: AdminState,
) -> tokio::task::JoinHandle<()> {
    let server = async move {
        match run_admin_config_server(addr, token, state).await {
            Ok(()) => {}
            Err(err) => warn!("RTR admin config server exited: {:?}", err),
        }
    };
    tokio::spawn(server)
}
async fn run_admin_config_server(
addr: SocketAddr,
token: Option<String>,
state: AdminState,
) -> Result<()> {
if token.is_none() && !addr.ip().is_loopback() {
return Err(anyhow!(
"RPKI_RTR_ADMIN_TOKEN is required when admin addr is not loopback: {}",
addr
));
}
let listener = TcpListener::bind(addr)
.await
.with_context(|| format!("bind RTR admin config server on {addr}"))?;
info!("RTR admin config server listening on {}", addr);
loop {
let (stream, peer_addr) = listener.accept().await?;
let token = token.clone();
let state = state.clone();
tokio::spawn(async move {
if let Err(err) = handle_admin_connection(stream, peer_addr, token, state).await {
warn!(
"RTR admin config request failed from {}: {:?}",
peer_addr, err
);
}
});
}
}
async fn handle_admin_connection(
mut stream: TcpStream,
peer_addr: SocketAddr,
token: Option<String>,
state: AdminState,
) -> Result<()> {
let mut buffer = vec![0u8; 64 * 1024];
let mut read = 0usize;
let header_end = loop {
if read == buffer.len() {
write_response(&mut stream, 413, "payload too large", "text/plain").await?;
return Ok(());
}
let n = stream.read(&mut buffer[read..]).await?;
if n == 0 {
return Ok(());
}
read += n;
if let Some(pos) = find_header_end(&buffer[..read]) {
break pos;
}
};
let header = std::str::from_utf8(&buffer[..header_end])
.map_err(|err| anyhow!("invalid HTTP header from {}: {}", peer_addr, err))?;
let request = parse_request_header(header)?;
if let Some(token) = token.as_deref() {
let expected = format!("Bearer {token}");
if request.authorization.as_deref() != Some(expected.as_str()) {
write_response(&mut stream, 401, "unauthorized", "text/plain").await?;
return Ok(());
}
}
let content_length = request.content_length.unwrap_or(0);
let max_body_bytes = if request.path.starts_with("/admin/rtr/slurm/") {
state
.slurm_admin
.as_ref()
.map(SlurmAdmin::max_body_bytes)
.unwrap_or(32 * 1024)
} else {
32 * 1024
};
if content_length > max_body_bytes {
write_response(&mut stream, 413, "payload too large", "text/plain").await?;
return Ok(());
}
let body_start = header_end + 4;
let available_body = read.saturating_sub(body_start);
let mut body = Vec::with_capacity(content_length);
body.extend_from_slice(&buffer[body_start..read]);
if available_body < content_length {
let remaining = content_length - available_body;
let mut tail = vec![0u8; remaining];
stream.read_exact(&mut tail).await?;
body.extend_from_slice(&tail);
}
body.truncate(content_length);
route_admin_request(&mut stream, request, body, state, peer_addr).await
}
/// Dispatches an authenticated admin request to the matching endpoint.
///
/// Paths below `/admin/rtr/slurm/files/` go to `route_slurm_file_request`;
/// anything that matches no endpoint is answered with 404.
async fn route_admin_request(
    stream: &mut TcpStream,
    request: RequestHeader,
    body: Vec<u8>,
    state: AdminState,
    peer_addr: SocketAddr,
) -> Result<()> {
    // None of the exact paths below starts with this prefix, so testing it first
    // routes exactly the same requests as testing it last.
    let per_file_target = request
        .path
        .strip_prefix("/admin/rtr/slurm/files/")
        .map(str::to_owned);
    if let Some(rest) = per_file_target {
        return route_slurm_file_request(stream, request, body, state, rest).await;
    }

    match (request.method.as_str(), request.path.as_str()) {
        ("POST", "/admin/rtr/config") => {
            let patch: RuntimeConfigPatch = match serde_json::from_slice(&body) {
                Ok(patch) => patch,
                Err(err) => {
                    let message = format!("invalid json: {err}");
                    write_response(stream, 400, &message, "text/plain").await?;
                    return Ok(());
                }
            };
            match state.runtime_config.apply_patch(patch) {
                Ok(config) => {
                    info!("RTR admin config updated from {}: {:?}", peer_addr, config);
                    let payload = AdminConfigResponse {
                        status: "ok",
                        config,
                    };
                    let json = serde_json::to_string_pretty(&payload)?;
                    write_response(stream, 200, &json, "application/json").await?;
                }
                Err(err) => {
                    let message = format!("invalid config: {err}");
                    write_response(stream, 400, &message, "text/plain").await?;
                }
            }
        }
        ("GET", "/admin/rtr/health") => {
            let payload = AdminHealthResponse {
                status: "ok",
                config_api: true,
                source_reload_api: state.source_reload.is_some(),
                slurm_api: state.slurm_admin.is_some(),
                logs_api: true,
            };
            let json = serde_json::to_string_pretty(&payload)?;
            write_response(stream, 200, &json, "application/json").await?;
        }
        ("GET", "/admin/rtr/logs/tail") => {
            return tail_log_stream(stream, request.query.as_deref(), &state.log_tail).await;
        }
        ("GET", "/admin/rtr/config") => {
            let payload = AdminConfigResponse {
                status: "ok",
                config: state.runtime_config.current(),
            };
            let json = serde_json::to_string_pretty(&payload)?;
            write_response(stream, 200, &json, "application/json").await?;
        }
        ("POST", "/admin/rtr/slurm/reload") => {
            let outcome = reload_source(&state, "admin_slurm_reload", true).await;
            return write_json_or_error(stream, outcome).await;
        }
        ("GET", "/admin/rtr/slurm/files") => {
            let Some(slurm_admin) = state.slurm_admin.as_ref() else {
                write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
                return Ok(());
            };
            let payload = SlurmFileListResponse {
                status: "ok",
                files: slurm_admin.list_files()?,
            };
            let json = serde_json::to_string_pretty(&payload)?;
            write_response(stream, 200, &json, "application/json").await?;
        }
        ("POST", "/admin/rtr/slurm/files") => {
            let Some(slurm_admin) = state.slurm_admin.as_ref() else {
                write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
                return Ok(());
            };
            let write_request = parse_json::<SlurmFileWriteRequest>(stream, &body).await?;
            let Some(name) = write_request.name.as_deref() else {
                write_response(stream, 400, "missing SLURM file name", "text/plain").await?;
                return Ok(());
            };
            let reload = write_request.reload.unwrap_or(false);
            let operation = slurm_admin.put_file(name, &write_request.content, "create_or_update");
            return apply_slurm_operation(stream, &state, operation, reload).await;
        }
        _ => write_response(stream, 404, "not found", "text/plain").await?,
    }
    Ok(())
}
/// Handles requests addressed to a single SLURM file.
///
/// `rest` is the part of the path after `/admin/rtr/slurm/files/`. It supports
/// `GET` (read), `PUT` (create or update), `DELETE`, and `POST` on
/// `<name>/enable` or `<name>/disable`. Everything else is answered with 404,
/// and all of it with 400 when SLURM administration is disabled.
async fn route_slurm_file_request(
    stream: &mut TcpStream,
    request: RequestHeader,
    body: Vec<u8>,
    state: AdminState,
    rest: String,
) -> Result<()> {
    let slurm_admin = match state.slurm_admin.as_ref() {
        Some(admin) => admin,
        None => {
            write_response(stream, 400, "SLURM admin is disabled", "text/plain").await?;
            return Ok(());
        }
    };
    let query = request.query.as_deref();

    match request.method.as_str() {
        "GET" => {
            match slurm_admin.read_file(&rest) {
                Ok(file) => {
                    let payload = SlurmFileContentResponse { status: "ok", file };
                    let json = serde_json::to_string_pretty(&payload)?;
                    write_response(stream, 200, &json, "application/json").await?;
                }
                Err(err) => {
                    let message = format!("invalid SLURM file request: {err}");
                    write_response(stream, 400, &message, "text/plain").await?;
                }
            }
            return Ok(());
        }
        "PUT" => {
            let write_request = parse_json::<SlurmFileWriteRequest>(stream, &body).await?;
            let reload = parse_reload_query(query, write_request.reload);
            let operation = slurm_admin.put_file(&rest, &write_request.content, "create_or_update");
            return apply_slurm_operation(stream, &state, operation, reload).await;
        }
        "DELETE" => {
            // A DELETE body is optional; when present it may carry the reload flag.
            let body_reload = if body.is_empty() {
                None
            } else {
                let action = parse_json::<SlurmFileActionRequest>(stream, &body).await?;
                action.reload
            };
            let reload = parse_reload_query(query, body_reload);
            let operation = slurm_admin.delete_file(&rest);
            return apply_slurm_operation(stream, &state, operation, reload).await;
        }
        "POST" => {
            if let Some(name) = rest.strip_suffix("/enable") {
                let action = parse_optional_json::<SlurmFileActionRequest>(stream, &body).await?;
                let reload = parse_reload_query(query, action.reload);
                let operation = slurm_admin.enable_file(name);
                return apply_slurm_operation(stream, &state, operation, reload).await;
            }
            if let Some(name) = rest.strip_suffix("/disable") {
                let action = parse_optional_json::<SlurmFileActionRequest>(stream, &body).await?;
                let reload = parse_reload_query(query, action.reload);
                let operation = slurm_admin.disable_file(name);
                return apply_slurm_operation(stream, &state, operation, reload).await;
            }
        }
        _ => {}
    }
    write_response(stream, 404, "not found", "text/plain").await?;
    Ok(())
}
/// Reports the result of a SLURM file operation and optionally reloads the source.
///
/// * A failed operation is answered with 400.
/// * Without `reload`, the operation result is returned as is.
/// * With `reload`, a successful reload is included in the 200 response. If the
///   reload fails, the operation is rolled back, a second reload is attempted
///   when the rollback succeeded, and a 400 response describes both the reload
///   error and the rollback outcome.
async fn apply_slurm_operation(
    stream: &mut TcpStream,
    state: &AdminState,
    operation: Result<crate::slurm::admin::AppliedSlurmFileOperation>,
    reload: bool,
) -> Result<()> {
    let applied = match operation {
        Ok(applied) => applied,
        Err(err) => {
            let message = format!("invalid SLURM operation: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
            return Ok(());
        }
    };

    if !reload {
        let payload = SlurmOperationResponse {
            status: "ok",
            operation: applied.result,
            reload: None,
            rollback: None,
        };
        let json = serde_json::to_string_pretty(&payload)?;
        write_response(stream, 200, &json, "application/json").await?;
        return Ok(());
    }

    let reload_error = match reload_source(state, "admin_slurm_changed", true).await {
        Ok(result) => {
            let payload = SlurmOperationResponse {
                status: "ok",
                operation: applied.result,
                reload: Some(result),
                rollback: None,
            };
            let json = serde_json::to_string_pretty(&payload)?;
            write_response(stream, 200, &json, "application/json").await?;
            return Ok(());
        }
        Err(err) => err,
    };

    // The reload rejected the new file set: undo the change, then try to reload
    // the restored files.
    let (rollback_status, rollback_reload, rollback_error) = match applied.rollback() {
        Ok(()) => match reload_source(state, "admin_slurm_rollback", true).await {
            Ok(result) => ("ok", Some(result), None),
            Err(second_reload_err) => ("reload_failed", None, Some(second_reload_err.to_string())),
        },
        Err(rollback_err) => ("failed", None, Some(rollback_err.to_string())),
    };
    let payload = SlurmOperationErrorResponse {
        status: "reload_failed",
        error: reload_error.to_string(),
        rollback: Some(RollbackReport {
            status: rollback_status.to_string(),
            reload: rollback_reload,
            error: rollback_error,
        }),
    };
    let json = serde_json::to_string_pretty(&payload)?;
    write_response(stream, 400, &json, "application/json").await?;
    Ok(())
}
/// Triggers a source reload through the handle stored in `state`.
///
/// # Errors
/// Fails with "source reload is not available" when no reload handle is
/// configured; otherwise returns whatever `SourceReloadHandle::reload` returns.
async fn reload_source(
    state: &AdminState,
    phase: &'static str,
    force: bool,
) -> Result<SourceReloadResult> {
    match state.source_reload.as_ref() {
        Some(handle) => handle.reload(phase, force).await,
        None => Err(anyhow!("source reload is not available")),
    }
}
/// Deserializes `body` as JSON into `T`.
///
/// On failure a 400 response carrying the parse error is written to `stream`
/// and the same message is returned as the error, so callers can simply use `?`.
async fn parse_json<T: for<'de> Deserialize<'de>>(
    stream: &mut TcpStream,
    body: &[u8],
) -> Result<T> {
    let parse_error = match serde_json::from_slice::<T>(body) {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };
    let message = format!("invalid json: {parse_error}");
    write_response(stream, 400, &message, "text/plain").await?;
    Err(anyhow!(message))
}
/// Like `parse_json`, but an empty body yields `T::default()` without parsing.
async fn parse_optional_json<T: for<'de> Deserialize<'de> + Default>(
    stream: &mut TcpStream,
    body: &[u8],
) -> Result<T> {
    if !body.is_empty() {
        return parse_json(stream, body).await;
    }
    Ok(T::default())
}
/// Writes `result` as an HTTP response: pretty JSON with status 200 on success,
/// or a plain-text `reload failed: …` message with status 400 on error.
async fn write_json_or_error<T: Serialize>(
    stream: &mut TcpStream,
    result: Result<T>,
) -> Result<()> {
    let (status, payload, content_type) = match result {
        Ok(value) => (
            200,
            serde_json::to_string_pretty(&value)?,
            "application/json",
        ),
        Err(err) => (400, format!("reload failed: {err}"), "text/plain"),
    };
    write_response(stream, status, &payload, content_type).await
}
/// Selects which log file the tail endpoint serves.
#[derive(Clone, Copy)]
enum LogStream {
/// The `<name>.stdout.log` file.
Stdout,
/// The `<name>.stderr.log` file.
Stderr,
}
impl LogStream {
fn parse(value: Option<&str>) -> Result<Self> {
match value.unwrap_or("stdout") {
"stdout" => Ok(Self::Stdout),
"stderr" => Ok(Self::Stderr),
other => Err(anyhow!(
"invalid stream '{}': expected 'stdout' or 'stderr'",
other
)),
}
}
fn as_str(self) -> &'static str {
match self {
Self::Stdout => "stdout",
Self::Stderr => "stderr",
}
}
}
/// Parsed query parameters of a log tail request.
struct LogTailRequest {
/// Which log file to read.
stream: LogStream,
/// Number of trailing lines to send first (clamped to 1..=5000 when parsed).
lines: usize,
/// Whether to keep streaming appended data after the initial tail.
follow: bool,
}
impl LogTailRequest {
fn parse(query: Option<&str>) -> Result<Self> {
let stream = LogStream::parse(query_param(query, "stream"))?;
let lines = query_param(query, "lines")
.map(|value| {
value
.parse::<usize>()
.map_err(|err| anyhow!("invalid lines '{}': {}", value, err))
})
.transpose()?
.unwrap_or(200)
.clamp(1, 5000);
let follow = query_param(query, "follow")
.map(|value| parse_bool_query(value, "follow"))
.transpose()?
.unwrap_or(true);
Ok(Self {
stream,
lines,
follow,
})
}
}
/// Streams the tail of a log file to the client using chunked transfer encoding.
///
/// Sends the last `lines` lines first. With `follow` enabled it then polls the
/// file once per second and forwards appended bytes until a write to the client
/// fails or the file becomes unavailable. If the file shrinks, streaming
/// restarts from the beginning of the file.
///
/// NOTE(review): while the file is idle nothing is written, so a client that
/// disconnected is only noticed on the next write.
///
/// # Errors
/// Returns I/O errors from the socket or the log file. Invalid query
/// parameters (400) and a missing file (404) are answered over HTTP instead.
async fn tail_log_stream(
    stream: &mut TcpStream,
    query: Option<&str>,
    config: &LogTailConfig,
) -> Result<()> {
    let request = match LogTailRequest::parse(query) {
        Ok(request) => request,
        Err(err) => {
            let message = format!("invalid log tail request: {err}");
            write_response(stream, 400, &message, "text/plain").await?;
            return Ok(());
        }
    };
    let path = config.path_for(request.stream);
    // Use the async metadata call; `Path::is_file` would block the executor thread.
    let is_file = tokio::fs::metadata(&path)
        .await
        .map(|metadata| metadata.is_file())
        .unwrap_or(false);
    if !is_file {
        let message = format!("log file not found: {}", path.display());
        write_response(stream, 404, &message, "text/plain").await?;
        return Ok(());
    }
    write_chunked_headers(stream, "text/plain; charset=utf-8").await?;
    let (tail, mut offset) = read_tail(&path, request.lines).await?;
    if !tail.is_empty() {
        write_chunk(stream, &tail).await?;
    }
    if !request.follow {
        write_final_chunk(stream).await?;
        return Ok(());
    }
    loop {
        sleep(Duration::from_secs(1)).await;
        let metadata = match tokio::fs::metadata(&path).await {
            Ok(metadata) => metadata,
            Err(err) => {
                // The stream is ending anyway, so failures to notify are ignored.
                let message = format!("\nlog file unavailable: {err}\n");
                let _ = write_chunk(stream, message.as_bytes()).await;
                let _ = write_final_chunk(stream).await;
                return Ok(());
            }
        };
        let len = metadata.len();
        if len < offset {
            offset = 0;
            write_chunk(stream, b"\nlog file truncated; restarting from beginning\n").await?;
        }
        if len == offset {
            continue;
        }
        let mut file = tokio::fs::File::open(&path).await?;
        file.seek(std::io::SeekFrom::Start(offset)).await?;
        let mut buf = Vec::new();
        file.read_to_end(&mut buf).await?;
        // Advance by what was actually read. The file may have grown since the
        // metadata call, and using `len` here would send those bytes twice.
        offset += buf.len() as u64;
        if !buf.is_empty() {
            write_chunk(stream, &buf).await?;
        }
    }
}
/// Reads the last `lines` lines of `path`, looking at no more than the final
/// 1 MiB of the file as measured when the function starts.
///
/// Returns the selected bytes and the file offset just past the last byte that
/// was read, which is where a follower should continue.
///
/// # Errors
/// Returns any I/O error from inspecting, opening, seeking or reading the file.
async fn read_tail(path: &Path, lines: usize) -> Result<(Vec<u8>, u64)> {
    const MAX_INITIAL_TAIL_BYTES: u64 = 1024 * 1024;
    let metadata = tokio::fs::metadata(path).await?;
    let start = metadata.len().saturating_sub(MAX_INITIAL_TAIL_BYTES);
    let mut file = tokio::fs::File::open(path).await?;
    file.seek(std::io::SeekFrom::Start(start)).await?;
    let mut buf = Vec::new();
    file.read_to_end(&mut buf).await?;
    // Derive the resume offset from the bytes actually read. The file may have
    // grown after the metadata call, and reporting the stale length would make
    // a follower send those bytes a second time.
    let next_offset = start + buf.len() as u64;
    Ok((tail_log_lines(buf, lines), next_offset))
}
/// Returns the last `lines` lines of `buf`.
///
/// Lines are separated by `\n`. A final line without a trailing newline counts
/// as a line, so `"a\nb"` with `lines == 1` yields `"b"`. When `buf` holds
/// `lines` lines or fewer it is returned unchanged. `lines == 0` or an empty
/// buffer yields an empty vector.
pub fn tail_log_lines(mut buf: Vec<u8>, lines: usize) -> Vec<u8> {
    if lines == 0 || buf.is_empty() {
        return Vec::new();
    }
    // A trailing newline terminates the last line rather than starting a new
    // one, so it is excluded from the separator search.
    let scan_end = if buf.last() == Some(&b'\n') {
        buf.len() - 1
    } else {
        buf.len()
    };
    // The `lines`-th separator counted from the end sits just before the first
    // byte to keep.
    let keep_from = buf[..scan_end]
        .iter()
        .enumerate()
        .rev()
        .filter(|(_, byte)| **byte == b'\n')
        .nth(lines - 1)
        .map(|(idx, _)| idx + 1);
    match keep_from {
        Some(start) => buf.split_off(start),
        None => buf,
    }
}
/// Looks up the first occurrence of `key` in an `a=b&c=d` style query string.
///
/// A parameter without `=` has the empty string as its value. Values are
/// returned as written; no percent-decoding is performed.
fn query_param<'a>(query: Option<&'a str>, key: &str) -> Option<&'a str> {
    for pair in query?.split('&') {
        let (name, value) = match pair.split_once('=') {
            Some(name_and_value) => name_and_value,
            None => (pair, ""),
        };
        if name == key {
            return Some(value);
        }
    }
    None
}
/// Parses a boolean query value.
///
/// Accepts `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off` (case-sensitive).
/// `name` is only used in the error message.
fn parse_bool_query(value: &str, name: &str) -> Result<bool> {
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 4] = ["false", "0", "no", "off"];
    if TRUTHY.iter().any(|candidate| *candidate == value) {
        Ok(true)
    } else if FALSY.iter().any(|candidate| *candidate == value) {
        Ok(false)
    } else {
        Err(anyhow!("invalid {} '{}': expected true/false", name, value))
    }
}
async fn write_chunked_headers(stream: &mut TcpStream, content_type: &str) -> Result<()> {
let response = format!(
"HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ntransfer-encoding: chunked\r\ncache-control: no-store\r\nconnection: close\r\n\r\n"
);
stream.write_all(response.as_bytes()).await?;
Ok(())
}
/// Writes one HTTP chunk (hex length, data, CRLF) and flushes the stream.
///
/// Empty input is skipped, because a zero-length chunk would terminate the
/// chunked body.
async fn write_chunk(stream: &mut TcpStream, chunk: &[u8]) -> Result<()> {
    if !chunk.is_empty() {
        let size_line = format!("{:x}\r\n", chunk.len());
        for part in [size_line.as_bytes(), chunk, &b"\r\n"[..]] {
            stream.write_all(part).await?;
        }
        stream.flush().await?;
    }
    Ok(())
}
async fn write_final_chunk(stream: &mut TcpStream) -> Result<()> {
stream.write_all(b"0\r\n\r\n").await?;
Ok(())
}
/// JSON body returned by `GET /admin/rtr/health`.
#[derive(Serialize)]
struct AdminHealthResponse {
/// Always `"ok"` when this response is produced.
status: &'static str,
/// Whether the configuration API is available (the health handler always sets `true`).
config_api: bool,
/// Whether a source reload handle is configured.
source_reload_api: bool,
/// Whether SLURM administration is configured.
slurm_api: bool,
/// Whether the log API is available (the health handler always sets `true`).
logs_api: bool,
}
/// JSON body returned by the `GET` and `POST /admin/rtr/config` endpoints.
#[derive(Serialize)]
struct AdminConfigResponse {
/// Always `"ok"` when this response is produced.
status: &'static str,
/// The configuration that is current after the request was handled.
config: RuntimeConfig,
}
/// JSON body returned by `GET /admin/rtr/slurm/files`.
#[derive(Serialize)]
struct SlurmFileListResponse {
/// Always `"ok"` when this response is produced.
status: &'static str,
/// Entries reported by `SlurmAdmin::list_files`.
files: Vec<crate::slurm::admin::SlurmFileListEntry>,
}
/// JSON body returned when a single SLURM file is read.
#[derive(Serialize)]
struct SlurmFileContentResponse {
/// Always `"ok"` when this response is produced.
status: &'static str,
/// File content as returned by `SlurmAdmin::read_file`.
file: crate::slurm::admin::SlurmFileContent,
}
/// JSON body returned after a successful SLURM file operation.
#[derive(Serialize)]
struct SlurmOperationResponse {
/// Always `"ok"` when this response is produced.
status: &'static str,
/// Result of the file operation itself.
operation: SlurmFileOperationResult,
/// Reload outcome; `None` when no reload was requested.
reload: Option<SourceReloadResult>,
/// Rollback report; the success paths in this file always set `None`.
rollback: Option<RollbackReport>,
}
/// JSON body returned when the reload after a SLURM file operation failed.
#[derive(Serialize)]
struct SlurmOperationErrorResponse {
/// Set to `"reload_failed"` by `apply_slurm_operation`.
status: &'static str,
/// Text of the reload error.
error: String,
/// What happened when the operation was rolled back.
rollback: Option<RollbackReport>,
}
/// Outcome of rolling back a SLURM file operation after a failed reload.
#[derive(Serialize)]
struct RollbackReport {
/// `"ok"`, `"reload_failed"` (rollback succeeded but the follow-up reload failed) or
/// `"failed"` (the rollback itself failed).
status: String,
/// Result of the reload performed after a successful rollback.
reload: Option<SourceReloadResult>,
/// Error text for the `"reload_failed"` and `"failed"` cases.
error: Option<String>,
}
/// The parts of an HTTP request header that the admin server uses.
struct RequestHeader {
/// Request method exactly as sent (compared case-sensitively by the routers).
method: String,
/// Request target up to, but excluding, the first `?`.
path: String,
/// Everything after the first `?`, if the target contains one.
query: Option<String>,
/// Parsed `Content-Length` header, if present.
content_length: Option<usize>,
/// Trimmed value of the `Authorization` header, if present.
authorization: Option<String>,
}
/// Parses the request line and the headers the admin server cares about.
///
/// Only `Content-Length` and `Authorization` are extracted (header names are
/// matched ASCII case-insensitively); lines without a `:` are ignored. When a
/// header is repeated, the last occurrence wins.
///
/// # Errors
/// Fails when the request line, method or target is missing, or when
/// `Content-Length` is not a valid unsigned integer.
fn parse_request_header(header: &str) -> Result<RequestHeader> {
    let mut lines = header.lines();
    let Some(request_line) = lines.next() else {
        return Err(anyhow!("missing HTTP request line"));
    };

    let mut tokens = request_line.split_whitespace();
    let method = match tokens.next() {
        Some(method) => method.to_string(),
        None => return Err(anyhow!("missing HTTP method")),
    };
    let target = match tokens.next() {
        Some(target) => target,
        None => return Err(anyhow!("missing HTTP path")),
    };
    let (path, query) = match target.split_once('?') {
        Some((path, query)) => (path.to_string(), Some(query.to_string())),
        None => (target.to_string(), None),
    };

    let mut content_length = None;
    let mut authorization = None;
    for (raw_name, raw_value) in lines.filter_map(|line| line.split_once(':')) {
        let name = raw_name.trim();
        let value = raw_value.trim();
        if name.eq_ignore_ascii_case("content-length") {
            let parsed = value
                .parse::<usize>()
                .map_err(|err| anyhow!("invalid content-length '{}': {}", value, err))?;
            content_length = Some(parsed);
        } else if name.eq_ignore_ascii_case("authorization") {
            authorization = Some(value.to_string());
        }
    }

    Ok(RequestHeader {
        method,
        path,
        query,
        content_length,
        authorization,
    })
}
/// Returns the index at which the first `\r\n\r\n` header terminator starts,
/// or `None` if `buffer` does not contain one yet.
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    // Only positions that leave room for the whole terminator can match.
    let candidates = buffer.len().saturating_sub(TERMINATOR.len() - 1);
    (0..candidates).find(|&start| buffer[start..].starts_with(TERMINATOR))
}
async fn write_response(
stream: &mut TcpStream,
status: u16,
body: &str,
content_type: &str,
) -> Result<()> {
let reason = match status {
200 => "OK",
400 => "Bad Request",
401 => "Unauthorized",
404 => "Not Found",
413 => "Payload Too Large",
_ => "Internal Server Error",
};
let response = format!(
"HTTP/1.1 {status} {reason}\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await?;
Ok(())
}