2026-04-01 16:24:01 +08:00

1037 lines
35 KiB
Rust

use std::env;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio::io::{self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, WriteHalf};
use tokio::net::TcpStream;
use tokio::time::{Duration, Instant, timeout};
use tokio_rustls::TlsConnector;
mod pretty;
mod protocol;
mod wire;
use crate::pretty::{parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu};
use crate::protocol::{PduHeader, PduType, QueryMode};
use crate::wire::{read_pdu, send_reset_query, send_serial_query};
/// Default number of seconds to wait for a PDU before printing a timeout notice.
const DEFAULT_READ_TIMEOUT_SECS: u64 = 30;
/// Default auto-poll interval (ten minutes), used until End of Data supplies a refresh hint.
const DEFAULT_POLL_INTERVAL_SECS: u64 = 10 * 60;
/// Bundles the traits a bidirectional async byte stream needs so that plain TCP and TLS
/// streams can both be stored behind a single boxed trait object.
trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}

impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite + Unpin + Send {}

/// Boxed transport stream produced by `connect_stream`.
type DynStream = Box<dyn AsyncStream>;
/// Write half of the split transport, used to send queries to the server.
type ClientWriter = WriteHalf<DynStream>;
#[tokio::main]
async fn main() -> io::Result<()> {
let config = Config::from_args()?;
println!("== RTR debug client ==");
println!("target : {}", config.addr);
println!("transport: {}", config.transport.describe());
println!("version : {}", config.version);
println!("timeout : {}s", config.read_timeout_secs);
println!(
"poll : {}s (default before EndOfData refresh is known)",
config.default_poll_secs
);
println!("keep-after-error: {}", config.keep_after_error);
match &config.mode {
QueryMode::Reset => {
println!("mode : reset");
}
QueryMode::Serial { session_id, serial } => {
println!("mode : serial");
println!("session : {}", session_id);
println!("serial : {}", serial);
}
}
println!();
print_help();
let mut state = ClientState::new(
config.version,
config.read_timeout_secs,
config.default_poll_secs,
config.keep_after_error,
);
let stdin = tokio::io::stdin();
let mut stdin_lines = BufReader::new(stdin).lines();
loop {
let stream = loop {
match connect_stream(&config).await {
Ok(stream) => {
println!("connected to {}", config.addr);
break stream;
}
Err(err) => {
let delay = state.reconnect_delay_secs();
eprintln!("connect failed: {}. retry after {}s", err, delay);
tokio::time::sleep(Duration::from_secs(delay)).await;
}
}
};
let (mut reader, mut writer) = tokio_io::split(stream);
send_resume_query(&mut writer, &mut state, &config.mode).await?;
state.schedule_next_poll();
println!();
let reconnect = loop {
let poll_sleep = tokio::time::sleep_until(state.poll_deadline());
tokio::pin!(poll_sleep);
tokio::select! {
line = stdin_lines.next_line() => {
match line {
Ok(Some(line)) => {
match handle_console_command(
&line,
&mut writer,
&mut state,
).await {
Ok(should_quit) => {
if should_quit {
println!("quit requested, closing client.");
return Ok(());
}
}
Err(err) if should_reconnect(&err) => {
eprintln!("command failed due to disconnected transport: {}", err);
break true;
}
Err(err) => return Err(err),
}
}
Ok(None) => {
println!("stdin closed, continue network loop.");
}
Err(err) => {
eprintln!("read stdin failed: {}", err);
}
}
}
_ = &mut poll_sleep => {
match handle_poll_tick(&mut writer, &mut state).await {
Ok(()) => state.schedule_next_poll(),
Err(err) if should_reconnect(&err) => {
eprintln!("auto poll failed due to disconnected transport: {}", err);
break true;
}
Err(err) => return Err(err),
}
}
read_result = timeout(
Duration::from_secs(state.read_timeout_secs),
read_pdu(&mut reader)
) => {
match read_result {
Ok(Ok(pdu)) => {
print_raw_pdu(&pdu.header, &pdu.body);
print_pdu(&pdu.header, &pdu.body);
match handle_incoming_pdu(&mut writer, &mut state, &pdu.header, &pdu.body).await {
Ok(()) => {}
Err(err) if should_reconnect(&err) => {
eprintln!("connection dropped while handling incoming PDU: {}", err);
break true;
}
Err(err) => return Err(err),
}
}
Ok(Err(err)) => {
eprintln!("read PDU failed: {}", err);
if should_reconnect(&err) {
break true;
}
return Err(err);
}
Err(_) => {
println!(
"[timeout] no PDU received in {}s, connection kept open.",
state.read_timeout_secs
);
}
}
}
}
};
if reconnect {
let delay = state.reconnect_delay_secs();
state.current_session_id = None;
println!("[reconnect] transport disconnected, retry after {}s", delay);
tokio::time::sleep(Duration::from_secs(delay)).await;
}
}
}
/// Sends the first query on a freshly opened connection.
///
/// If a session id and serial are already known from earlier traffic, the session is resumed
/// with a Serial Query. Otherwise the query configured on the command line (`mode`) is sent.
/// A `Serial` mode also seeds `state.session_id` / `state.serial` with its values.
///
/// # Errors
/// Propagates I/O errors from writing the query.
async fn send_resume_query(
    writer: &mut ClientWriter,
    state: &mut ClientState,
    mode: &QueryMode,
) -> io::Result<()> {
    if let Some((session_id, serial)) = state.session_id.zip(state.serial) {
        println!(
            "reconnected, send Serial Query with session_id={}, serial={}",
            session_id, serial
        );
        send_serial_query(writer, state.version, session_id, serial).await?;
        return Ok(());
    }

    match *mode {
        QueryMode::Serial { session_id, serial } => {
            state.session_id = Some(session_id);
            state.serial = Some(serial);
            send_serial_query(writer, state.version, session_id, serial).await?;
            println!("sent Serial Query");
        }
        QueryMode::Reset => {
            send_reset_query(writer, state.version).await?;
            println!("sent Reset Query");
        }
    }
    Ok(())
}
/// Returns `true` when `err` means the transport itself is gone (EOF, reset, aborted,
/// broken pipe, not connected), so the caller should reconnect rather than abort.
fn should_reconnect(err: &io::Error) -> bool {
    const TRANSPORT_LOST: [io::ErrorKind; 5] = [
        io::ErrorKind::UnexpectedEof,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::NotConnected,
    ];
    TRANSPORT_LOST.contains(&err.kind())
}
/// Updates client state, and sends any follow-up query, in response to one received PDU.
///
/// Printing of the PDU itself is done by the caller; this function only prints state
/// transitions. Per PDU type:
/// - Cache Response: remember the session id from the header and clear `last_error_code`.
/// - payload PDUs (IPv4/IPv6 prefix, Router Key, ASPA): adopt the header session id if none
///   is known on this connection yet.
/// - End of Data: commit session id, serial and timing hints, then reschedule the auto poll.
/// - Serial Notify: send a Serial Query with the *locally stored* serial if the notification
///   is for the stored session; otherwise send a Reset Query.
/// - Cache Reset: forget the serial and current session, then send a Reset Query.
/// - Error Report: record the error code and either keep polling or pause auto polling,
///   depending on `keep_after_error`.
///
/// # Errors
/// Propagates I/O errors from `send_serial_query` / `send_reset_query`.
async fn handle_incoming_pdu(
    writer: &mut ClientWriter,
    state: &mut ClientState,
    header: &PduHeader,
    body: &[u8],
) -> io::Result<()> {
    match header.pdu_type() {
        PduType::CacheResponse => {
            state.current_session_id = Some(header.session_id());
            state.last_error_code = None;
        }
        PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa => {
            // Only fill in the session id if no Cache Response set it on this connection.
            if state.current_session_id.is_none() {
                state.current_session_id = Some(header.session_id());
            }
        }
        PduType::EndOfData => {
            let session_id = header.session_id();
            let eod = parse_end_of_data_info(body);
            state.session_id = Some(session_id);
            state.current_session_id = Some(session_id);
            println!();
            if let Some(eod) = eod {
                state.serial = Some(eod.serial);
                // Timing hints are replaced wholesale; a missing hint (`None`) clears the
                // stored value so polling falls back to the default interval.
                state.refresh = eod.refresh;
                state.retry = eod.retry;
                state.expire = eod.expire;
                state.last_error_code = None;
                println!(
                    "updated client state: session_id={}, serial={}",
                    session_id, eod.serial
                );
                if let Some(refresh) = eod.refresh {
                    println!("refresh : {}", refresh);
                }
                if let Some(retry) = eod.retry {
                    println!("retry : {}", retry);
                }
                if let Some(expire) = eod.expire {
                    println!("expire : {}", expire);
                }
                state.schedule_next_poll();
                println!(
                    "next auto poll scheduled after {}s",
                    state.effective_poll_secs()
                );
            } else {
                // NOTE(review): `state.serial` keeps its previous value here while
                // `state.session_id` was just replaced, so the pair may be inconsistent
                // when the body could not be parsed. Confirm this is intended.
                println!(
                    "updated client state: session_id={}, serial=<unknown>",
                    session_id
                );
            }
            println!("received EndOfData, keep connection open.");
            println!();
        }
        PduType::SerialNotify => {
            let notify_session_id = header.session_id();
            let notify_serial = parse_serial_notify_serial(body);
            println!();
            println!(
                "[notify] received Serial Notify: session_id={}, notify_serial={:?}",
                notify_session_id, notify_serial
            );
            // The notified serial is only required to be parseable; the query itself uses
            // the locally stored serial so the server sends the delta since our last sync.
            match (state.session_id, state.serial, notify_serial) {
                (Some(current_session_id), Some(current_serial), Some(_new_serial))
                    if current_session_id == notify_session_id =>
                {
                    println!(
                        "received Serial Notify for current session {}, send Serial Query with serial {}",
                        current_session_id, current_serial
                    );
                    send_serial_query(writer, state.version, current_session_id, current_serial)
                        .await?;
                }
                _ => {
                    println!(
                        "received Serial Notify but local session/serial state is not usable, send Reset Query"
                    );
                    send_reset_query(writer, state.version).await?;
                }
            }
            state.schedule_next_poll();
            println!();
        }
        PduType::CacheReset => {
            println!();
            println!("received Cache Reset, send Reset Query");
            // `state.session_id` is intentionally left as-is; only the serial and the
            // per-connection session are forgotten.
            state.current_session_id = None;
            state.serial = None;
            state.last_error_code = None;
            send_reset_query(writer, state.version).await?;
            state.schedule_next_poll();
            println!();
        }
        PduType::ErrorReport => {
            println!();
            println!("received Error Report, pause auto polling for debugging.");
            state.last_error_code = Some(header.error_code());
            if let Some(retry) = state.retry {
                println!("server retry hint currently stored: {}s", retry);
                if state.should_prefer_retry_poll() {
                    println!("when resumed, auto polling will use retry instead of refresh.");
                }
            }
            if state.keep_after_error {
                println!("keep-after-error is enabled, auto polling will continue.");
                state.schedule_next_poll();
            } else {
                println!("use `reset`, `serial`, or `poll resume` to continue manually.");
                state.pause_auto_poll();
            }
            println!();
        }
        PduType::SerialQuery | PduType::ResetQuery | PduType::Unknown(_) => {
            // only print, no extra action
        }
    }
    Ok(())
}
/// Performs one auto poll.
///
/// Sends a Serial Query when both a session id and a serial are known, and a Reset Query
/// otherwise. Rescheduling the next poll is left to the caller.
///
/// # Errors
/// Propagates I/O errors from sending the query.
async fn handle_poll_tick(writer: &mut ClientWriter, state: &mut ClientState) -> io::Result<()> {
    let interval_secs = state.effective_poll_secs();
    println!();
    println!("[auto-poll] timer fired (interval={}s)", interval_secs);

    if let Some((session_id, serial)) = state.session_id.zip(state.serial) {
        println!(
            "[auto-poll] send Serial Query with session_id={}, serial={}",
            session_id, serial
        );
        send_serial_query(writer, state.version, session_id, serial).await?;
    } else {
        println!("[auto-poll] local state incomplete, send Reset Query");
        send_reset_query(writer, state.version).await?;
    }

    println!();
    Ok(())
}
/// Executes one line typed on the interactive console.
///
/// Blank lines are ignored and unknown commands print the help text. Commands that send
/// a query also reschedule the next auto poll.
///
/// Returns `Ok(true)` when the user asked to quit (`quit` / `exit`), `Ok(false)` otherwise.
///
/// # Errors
/// Propagates I/O errors from sending a Reset or Serial Query over `writer`.
async fn handle_console_command(
    line: &str,
    writer: &mut ClientWriter,
    state: &mut ClientState,
) -> io::Result<bool> {
    let line = line.trim();
    if line.is_empty() {
        return Ok(false);
    }
    let parts: Vec<&str> = line.split_whitespace().collect();
    match parts.as_slice() {
        ["help"] => {
            print_help();
        }
        ["state"] => {
            print_state(state);
        }
        ["reset"] => {
            println!("manual command: send Reset Query");
            send_reset_query(writer, state.version).await?;
            state.schedule_next_poll();
        }
        ["serial"] => match (state.session_id, state.serial) {
            (Some(session_id), Some(serial)) => {
                println!(
                    "manual command: send Serial Query with current state: session_id={}, serial={}",
                    session_id, serial
                );
                send_serial_query(writer, state.version, session_id, serial).await?;
                state.schedule_next_poll();
            }
            _ => {
                println!(
                    "manual command failed: current session_id/serial not available, use `reset` or `serial <session_id> <serial>`"
                );
            }
        },
        ["serial", session_id, serial] => {
            let session_id = match session_id.parse::<u16>() {
                Ok(v) => v,
                Err(err) => {
                    println!("invalid session_id: {}", err);
                    return Ok(false);
                }
            };
            let serial = match serial.parse::<u32>() {
                Ok(v) => v,
                Err(err) => {
                    println!("invalid serial: {}", err);
                    return Ok(false);
                }
            };
            println!(
                "manual command: send Serial Query with explicit args: session_id={}, serial={}",
                session_id, serial
            );
            // Explicit values become the stored state so later auto polls continue from them.
            state.session_id = Some(session_id);
            state.serial = Some(serial);
            send_serial_query(writer, state.version, session_id, serial).await?;
            state.schedule_next_poll();
        }
        ["timeout"] => {
            println!("current read timeout: {}s", state.read_timeout_secs);
        }
        ["timeout", secs] => {
            let secs = match secs.parse::<u64>() {
                Ok(v) if v > 0 => v,
                Ok(_) => {
                    println!("timeout must be > 0");
                    return Ok(false);
                }
                Err(err) => {
                    println!("invalid timeout seconds: {}", err);
                    return Ok(false);
                }
            };
            state.read_timeout_secs = secs;
            println!("updated read timeout to {}s", state.read_timeout_secs);
        }
        ["poll"] => {
            println!(
                "current effective poll interval: {}s",
                state.effective_poll_secs()
            );
            println!(
                "poll interval source : {}",
                state.poll_interval_source()
            );
            println!("stored refresh hint : {:?}", state.refresh);
            println!("default poll interval : {}s", state.default_poll_secs);
            println!("last_error_code : {:?}", state.last_error_code);
            println!("auto polling paused : {}", state.poll_paused);
        }
        ["poll", "pause"] => {
            state.pause_auto_poll();
            println!("auto polling paused");
        }
        ["poll", "resume"] => {
            state.resume_auto_poll();
            println!(
                "auto polling resumed, next poll scheduled after {}s",
                state.effective_poll_secs()
            );
        }
        ["poll", secs] => {
            // The override is stored in the u32 `refresh` slot, so parse as u32 directly.
            // The previous `u64 as u32` cast silently wrapped values above u32::MAX.
            // Note that a later End of Data refresh hint replaces this override.
            let secs = match secs.parse::<u32>() {
                Ok(v) if v > 0 => v,
                Ok(_) => {
                    println!("poll interval must be > 0");
                    return Ok(false);
                }
                Err(err) => {
                    println!("invalid poll seconds: {}", err);
                    return Ok(false);
                }
            };
            state.refresh = Some(secs);
            state.schedule_next_poll();
            println!("updated poll interval to {}s", secs);
        }
        // Listed in `print_help`; previously fell through to "unknown command".
        ["keep-after-error"] => {
            println!("keep-after-error: {}", state.keep_after_error);
        }
        ["quit"] | ["exit"] => {
            return Ok(true);
        }
        _ => {
            println!("unknown command: {}", line);
            print_help();
        }
    }
    Ok(false)
}
/// Prints the list of interactive console commands followed by a blank line.
fn print_help() {
    const HELP_LINES: &[&str] = &[
        "available commands:",
        " help show this help",
        " state print current client state",
        " reset send Reset Query",
        " serial send Serial Query with current session_id/serial",
        " serial <sid> <serial> send Serial Query with explicit values",
        " timeout show current read timeout",
        " timeout <secs> update read timeout seconds",
        " poll show current poll interval",
        " poll <secs> override poll interval seconds",
        " poll pause pause auto polling",
        " poll resume resume auto polling",
        " keep-after-error show current keep-after-error setting",
        " quit exit client",
    ];
    for entry in HELP_LINES {
        println!("{}", entry);
    }
    println!();
}
/// Dumps every field of `state`, plus the derived poll interval and its source, to stdout.
fn print_state(state: &ClientState) {
    let rows: [(&str, String); 14] = [
        (" version : ", state.version.to_string()),
        (" session_id : ", format!("{:?}", state.session_id)),
        (" serial : ", format!("{:?}", state.serial)),
        (" current_session_id : ", format!("{:?}", state.current_session_id)),
        (" refresh : ", format!("{:?}", state.refresh)),
        (" retry : ", format!("{:?}", state.retry)),
        (" expire : ", format!("{:?}", state.expire)),
        (" read_timeout_secs : ", state.read_timeout_secs.to_string()),
        (" default_poll_secs : ", state.default_poll_secs.to_string()),
        (" effective_poll_secs: ", state.effective_poll_secs().to_string()),
        (" poll_source : ", state.poll_interval_source().to_string()),
        (" last_error_code : ", format!("{:?}", state.last_error_code)),
        (" keep_after_error : ", state.keep_after_error.to_string()),
        (" poll_paused : ", state.poll_paused.to_string()),
    ];
    println!("client state:");
    for (label, value) in &rows {
        println!("{}{}", label, value);
    }
    println!();
}
/// Mutable state of the RTR client, kept across reconnects.
///
/// Tracks the last known session id and serial (used to resume with a Serial Query), the
/// timing hints from the most recent End of Data PDU, and the auto-poll schedule.
#[derive(Debug)]
struct ClientState {
    /// RTR protocol version used for outgoing queries.
    version: u8,
    /// Session id used for Serial Queries.
    session_id: Option<u16>,
    /// Serial used for Serial Queries; auto polls fall back to a Reset Query while `None`.
    serial: Option<u32>,
    /// Session id observed on the current connection; cleared on reconnect and Cache Reset.
    current_session_id: Option<u16>,
    /// Refresh interval hint in seconds (from End of Data, or overridden via `poll <secs>`).
    refresh: Option<u32>,
    /// Retry interval hint in seconds from End of Data.
    retry: Option<u32>,
    /// Expire interval hint in seconds from End of Data.
    expire: Option<u32>,
    /// Error code of the most recent Error Report, until cleared by a later response.
    last_error_code: Option<u16>,
    /// Keep auto polling after an Error Report instead of pausing.
    keep_after_error: bool,
    /// Seconds to wait for a PDU before printing a timeout notice.
    read_timeout_secs: u64,
    /// Poll interval used when no refresh/retry hint applies.
    default_poll_secs: u64,
    /// When the next auto poll should fire (ignored while paused).
    next_poll_deadline: Instant,
    /// Whether auto polling is paused.
    poll_paused: bool,
}

impl ClientState {
    /// Horizon (one year) used for "effectively never" deadlines: a paused poll timer, or a
    /// requested interval too large to represent as an `Instant`.
    const FAR_FUTURE_SECS: u64 = 365 * 24 * 60 * 60;

    /// Creates a state with no session knowledge and the first poll `default_poll_secs` away.
    fn new(
        version: u8,
        read_timeout_secs: u64,
        default_poll_secs: u64,
        keep_after_error: bool,
    ) -> Self {
        Self {
            version,
            session_id: None,
            serial: None,
            current_session_id: None,
            refresh: None,
            retry: None,
            expire: None,
            last_error_code: None,
            keep_after_error,
            read_timeout_secs,
            default_poll_secs,
            next_poll_deadline: Self::deadline_after(default_poll_secs),
            poll_paused: false,
        }
    }

    /// Poll interval in seconds currently in effect.
    ///
    /// After an error for which retry is preferred (see `should_prefer_retry_poll`), the
    /// order is retry → refresh → default. Otherwise it is refresh → default.
    fn effective_poll_secs(&self) -> u64 {
        if self.should_prefer_retry_poll() {
            self.retry.map(|v| v as u64).unwrap_or_else(|| {
                self.refresh
                    .map(|v| v as u64)
                    .unwrap_or(self.default_poll_secs)
            })
        } else {
            self.refresh
                .map(|v| v as u64)
                .unwrap_or(self.default_poll_secs)
        }
    }

    /// Sets the next poll deadline to `effective_poll_secs()` from now.
    fn schedule_next_poll(&mut self) {
        self.next_poll_deadline = Self::deadline_after(self.effective_poll_secs());
    }

    /// Stops auto polling until `resume_auto_poll` is called.
    fn pause_auto_poll(&mut self) {
        self.poll_paused = true;
    }

    /// Re-enables auto polling and schedules the next poll from now.
    fn resume_auto_poll(&mut self) {
        self.poll_paused = false;
        self.schedule_next_poll();
    }

    /// Deadline the poll timer should sleep until; one year out while paused.
    fn poll_deadline(&self) -> Instant {
        if self.poll_paused {
            Self::deadline_after(Self::FAR_FUTURE_SECS)
        } else {
            self.next_poll_deadline
        }
    }

    /// Whether the last Error Report code calls for the server's retry interval.
    ///
    /// NOTE(review): codes 2 and 10 are hard-coded. 2 is "No Data Available" in RFC 8210;
    /// confirm the meaning of 10 against the protocol module.
    fn should_prefer_retry_poll(&self) -> bool {
        matches!(self.last_error_code, Some(2 | 10))
    }

    /// Name of the value `effective_poll_secs` is currently derived from.
    fn poll_interval_source(&self) -> &'static str {
        if self.should_prefer_retry_poll() && self.retry.is_some() {
            "retry"
        } else if self.refresh.is_some() {
            "refresh"
        } else {
            "default"
        }
    }

    /// Seconds to wait before reconnecting: the retry hint after a retry-preferring error,
    /// otherwise the default poll interval.
    fn reconnect_delay_secs(&self) -> u64 {
        if self.should_prefer_retry_poll() {
            self.retry
                .map(|v| v as u64)
                .unwrap_or(self.default_poll_secs)
        } else {
            self.default_poll_secs
        }
    }

    /// `now + secs`, without panicking on overflow.
    ///
    /// `Instant + Duration` panics when the sum is unrepresentable. That is reachable with a
    /// huge user-supplied `--poll`, so fall back to the far-future horizon (or `now` if even
    /// that fails).
    fn deadline_after(secs: u64) -> Instant {
        let now = Instant::now();
        now.checked_add(Duration::from_secs(secs))
            .or_else(|| now.checked_add(Duration::from_secs(Self::FAR_FUTURE_SECS)))
            .unwrap_or(now)
    }
}
/// Runtime configuration assembled from the command line.
///
/// Positional arguments: `[addr] [version] [reset | serial <session_id> <serial>]`.
/// `--option` flags may appear anywhere among them.
#[derive(Debug)]
struct Config {
    /// `host:port` the client connects to.
    addr: String,
    /// RTR protocol version used in queries (0..=2).
    version: u8,
    /// Initial query mode used before any session state is known.
    mode: QueryMode,
    /// Seconds to wait for a PDU before printing a timeout notice.
    read_timeout_secs: u64,
    /// Poll interval used until the server supplies a refresh hint.
    default_poll_secs: u64,
    /// Plain TCP or TLS settings.
    transport: TransportConfig,
    /// Keep auto polling after an Error Report.
    keep_after_error: bool,
}

impl Config {
    /// Parses `std::env::args()` into a `Config`.
    ///
    /// Defaults: address `127.0.0.1:323`, version 1, mode `reset`.
    ///
    /// # Errors
    /// Returns `InvalidInput` for unknown options, options missing their value, unparsable
    /// or zero numbers, a version above 2, an unknown mode, or incomplete TLS settings.
    fn from_args() -> io::Result<Self> {
        /// Builds an `InvalidInput` error carrying `msg`.
        fn invalid(msg: String) -> io::Error {
            io::Error::new(io::ErrorKind::InvalidInput, msg)
        }

        /// Takes the value that must follow an option flag, failing with `missing` if absent.
        fn value_for(
            args: &mut impl Iterator<Item = String>,
            missing: &str,
        ) -> io::Result<String> {
            args.next().ok_or_else(|| invalid(missing.to_string()))
        }

        let mut args = env::args().skip(1);
        let mut positional: Vec<String> = Vec::new();
        let mut transport = TransportConfig::Tcp;
        let mut read_timeout_secs = DEFAULT_READ_TIMEOUT_SECS;
        let mut default_poll_secs = DEFAULT_POLL_INTERVAL_SECS;
        let mut keep_after_error = false;

        while let Some(arg) = args.next() {
            match arg.as_str() {
                "--tls" => {
                    // Switches to TLS with empty settings unless a TLS option already did.
                    ensure_tls_config(&mut transport)?;
                }
                "--ca-cert" => {
                    let path = value_for(&mut args, "--ca-cert requires a path")?;
                    ensure_tls_config(&mut transport)?.ca_cert = Some(PathBuf::from(path));
                }
                "--client-cert" => {
                    let path = value_for(&mut args, "--client-cert requires a path")?;
                    ensure_tls_config(&mut transport)?.client_cert = Some(PathBuf::from(path));
                }
                "--client-key" => {
                    let path = value_for(&mut args, "--client-key requires a path")?;
                    ensure_tls_config(&mut transport)?.client_key = Some(PathBuf::from(path));
                }
                "--server-name" => {
                    let name = value_for(&mut args, "--server-name requires a value")?;
                    ensure_tls_config(&mut transport)?.server_name = Some(name);
                }
                "--timeout" => {
                    let secs = value_for(&mut args, "--timeout requires seconds")?;
                    read_timeout_secs = parse_u64_arg(&secs, "--timeout")?;
                }
                "--poll" => {
                    let secs = value_for(&mut args, "--poll requires seconds")?;
                    default_poll_secs = parse_u64_arg(&secs, "--poll")?;
                }
                "--keep-after-error" => keep_after_error = true,
                other if other.starts_with("--") => {
                    return Err(invalid(format!("unknown option '{}'", other)));
                }
                _ => positional.push(arg),
            }
        }

        let mut rest = positional.into_iter();
        let addr = rest
            .next()
            .unwrap_or_else(|| String::from("127.0.0.1:323"));
        let version: u8 = match rest.next() {
            None => 1,
            Some(text) => text
                .parse::<u8>()
                .map_err(|e| invalid(format!("invalid version '{}': {}", text, e)))?,
        };
        if version > 2 {
            return Err(invalid(format!(
                "unsupported RTR version {}, expected 0..=2",
                version
            )));
        }

        let mode = match rest.next().as_deref() {
            None | Some("reset") => QueryMode::Reset,
            Some("serial") => {
                let session_id = rest
                    .next()
                    .ok_or_else(|| {
                        invalid("serial mode requires session_id and serial".to_string())
                    })?
                    .parse::<u16>()
                    .map_err(|e| invalid(format!("invalid session_id: {}", e)))?;
                let serial = rest
                    .next()
                    .ok_or_else(|| invalid("serial mode requires serial".to_string()))?
                    .parse::<u32>()
                    .map_err(|e| invalid(format!("invalid serial: {}", e)))?;
                QueryMode::Serial { session_id, serial }
            }
            Some(other) => {
                return Err(invalid(format!(
                    "invalid mode '{}', expected 'reset' or 'serial'",
                    other
                )));
            }
        };

        let transport = finalize_transport(transport, &addr)?;
        Ok(Self {
            addr,
            version,
            mode,
            read_timeout_secs,
            default_poll_secs,
            transport,
            keep_after_error,
        })
    }
}
/// How the client reaches the RTR server.
#[derive(Debug, Clone)]
enum TransportConfig {
    /// Plain TCP.
    Tcp,
    /// TLS over TCP with the given settings.
    Tls(TlsConfig),
}

impl TransportConfig {
    /// One-line human-readable summary used in the startup banner.
    fn describe(&self) -> String {
        let cfg = match self {
            Self::Tcp => return String::from("tcp"),
            Self::Tls(cfg) => cfg,
        };
        // Renders an optional path, substituting `fallback` when it is absent.
        let show = |path: &Option<PathBuf>, fallback: &str| -> String {
            match path {
                Some(p) => p.display().to_string(),
                None => fallback.to_string(),
            }
        };
        format!(
            "tls (server_name={}, ca_cert={}, client_cert={})",
            cfg.server_name.as_deref().unwrap_or("<unset>"),
            show(&cfg.ca_cert, "<unset>"),
            show(&cfg.client_cert, "<none>"),
        )
    }
}

/// TLS settings gathered from `--server-name`, `--ca-cert`, `--client-cert` and `--client-key`.
#[derive(Debug, Clone, Default)]
struct TlsConfig {
    /// Name used for SNI and certificate verification.
    server_name: Option<String>,
    /// PEM bundle of trusted CA certificates.
    ca_cert: Option<PathBuf>,
    /// PEM client certificate chain for client authentication.
    client_cert: Option<PathBuf>,
    /// PEM private key matching `client_cert`.
    client_key: Option<PathBuf>,
}
/// Returns the TLS settings inside `transport`, first switching a plain-TCP transport to
/// TLS with default (empty) settings. Never returns `Err`; the `Result` keeps call sites
/// uniform with other fallible option handlers.
fn ensure_tls_config(transport: &mut TransportConfig) -> io::Result<&mut TlsConfig> {
    if let TransportConfig::Tcp = transport {
        *transport = TransportConfig::Tls(TlsConfig::default());
    }
    if let TransportConfig::Tls(cfg) = transport {
        Ok(cfg)
    } else {
        unreachable!()
    }
}
/// Validates parsed TLS settings and fills in the server name.
///
/// Plain TCP passes through unchanged. For TLS: a CA bundle is mandatory, client cert and
/// key must be given together or not at all, and a missing `--server-name` is derived from
/// the host part of `addr`.
///
/// # Errors
/// `InvalidInput` describing the first violated rule, checked in the order above.
fn finalize_transport(transport: TransportConfig, addr: &str) -> io::Result<TransportConfig> {
    let cfg = match transport {
        TransportConfig::Tcp => return Ok(TransportConfig::Tcp),
        TransportConfig::Tls(cfg) => cfg,
    };
    let TlsConfig {
        server_name,
        ca_cert,
        client_cert,
        client_key,
    } = cfg;
    let invalid = |msg: &str| io::Error::new(io::ErrorKind::InvalidInput, msg);

    let ca_cert = ca_cert.ok_or_else(|| invalid("TLS mode requires --ca-cert <path>"))?;

    if client_cert.is_some() != client_key.is_some() {
        return Err(invalid(
            "TLS client authentication requires both --client-cert and --client-key",
        ));
    }

    let server_name = server_name
        .or_else(|| default_server_name_for_addr(addr))
        .ok_or_else(|| {
            invalid("TLS mode requires --server-name or an address with a parsable host")
        })?;

    Ok(TransportConfig::Tls(TlsConfig {
        server_name: Some(server_name),
        ca_cert: Some(ca_cert),
        client_cert,
        client_key,
    }))
}
/// Opens a connection to `config.addr` using the configured transport (plain TCP or TLS).
///
/// # Errors
/// Propagates connection and TLS setup errors.
async fn connect_stream(config: &Config) -> io::Result<DynStream> {
    let addr = config.addr.as_str();
    let stream: DynStream = match &config.transport {
        TransportConfig::Tls(tls) => return connect_tls_stream(addr, tls).await,
        TransportConfig::Tcp => Box::new(TcpStream::connect(addr).await?),
    };
    Ok(stream)
}
/// Opens a TCP connection to `addr` and performs a TLS handshake over it.
///
/// The local TLS configuration (server name, CA bundle, optional client cert/key) is
/// validated and loaded *before* the TCP connect. A misconfiguration therefore fails fast
/// instead of opening, and immediately dropping, a connection to the server on every retry.
///
/// # Errors
/// - `InvalidInput` if the server name is missing or not a valid DNS name/IP address.
/// - errors from `build_tls_connector` (e.g. unreadable or invalid certificate files).
/// - errors from `TcpStream::connect`.
/// - `ConnectionAborted` if the TLS handshake fails.
async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result<DynStream> {
    let server_name_str = tls
        .server_name
        .as_ref()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS server name"))?;
    let server_name = ServerName::try_from(server_name_str.clone()).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("invalid TLS server name '{}': {}", server_name_str, err),
        )
    })?;
    let connector = build_tls_connector(tls)?;

    let stream = TcpStream::connect(addr).await?;
    let tls_stream = connector
        .connect(server_name, stream)
        .await
        .map_err(|err| {
            io::Error::new(
                io::ErrorKind::ConnectionAborted,
                format!("TLS handshake failed: {}", err),
            )
        })?;
    Ok(Box::new(tls_stream))
}
/// Builds a rustls-backed `TlsConnector` from `tls`.
///
/// Trust anchors come only from the `--ca-cert` PEM bundle; the root store starts empty.
/// When both a client certificate and key are configured they are presented for client
/// authentication; when neither is, no client certificate is sent.
///
/// # Errors
/// - `InvalidInput` if no CA path is set, no usable CA certificate was found, only one of
///   client cert/key is set, or rustls rejects the client cert/key pair.
/// - errors from `load_certs` / `load_private_key` (missing files, unparsable PEM).
fn build_tls_connector(tls: &TlsConfig) -> io::Result<TlsConnector> {
    let ca_cert_path = tls
        .ca_cert
        .as_ref()
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS CA cert"))?;
    let ca_certs = load_certs(ca_cert_path)?;
    let mut roots = RootCertStore::empty();
    // Certificates rustls cannot parse are skipped; fail only if none were usable.
    let (added, _ignored) = roots.add_parsable_certificates(ca_certs);
    if added == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "no valid CA certificates found in {}",
                ca_cert_path.display()
            ),
        ));
    }
    let builder = RustlsClientConfig::builder().with_root_certificates(roots);
    let client_config = match (&tls.client_cert, &tls.client_key) {
        (Some(cert_path), Some(key_path)) => {
            let certs = load_certs(cert_path)?;
            let key = load_private_key(key_path)?;
            builder.with_client_auth_cert(certs, key).map_err(|err| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("invalid TLS client certificate/key: {}", err),
                )
            })?
        }
        (None, None) => builder.with_no_client_auth(),
        // `finalize_transport` already rejects a lone cert or key. Report it as a
        // configuration error rather than panicking if an unvalidated config gets here.
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "TLS client authentication requires both --client-cert and --client-key",
            ));
        }
    };
    Ok(TlsConnector::from(Arc::new(client_config)))
}
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
/// Propagates the file-open error. Returns `InvalidData` for a malformed PEM section, or if
/// the file holds no certificates at all.
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
    let file = std::fs::File::open(path)?;
    let mut pem = std::io::BufReader::new(file);
    let mut certs = Vec::new();
    for parsed in rustls_pemfile::certs(&mut pem) {
        let cert = parsed.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
        certs.push(cert);
    }
    if certs.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("no certificates found in {}", path.display()),
        ));
    }
    Ok(certs)
}
/// Reads the private key returned by `rustls_pemfile::private_key` from the PEM file at `path`.
///
/// # Errors
/// Propagates the file-open error. Returns `InvalidData` for malformed PEM, or if the file
/// holds no private key.
fn load_private_key(path: &Path) -> io::Result<PrivateKeyDer<'static>> {
    let file = std::fs::File::open(path)?;
    let mut pem = std::io::BufReader::new(file);
    match rustls_pemfile::private_key(&mut pem) {
        Ok(Some(key)) => Ok(key),
        Ok(None) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("no private key found in {}", path.display()),
        )),
        Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)),
    }
}
/// Derives a TLS server name from a connect address when `--server-name` is not given.
///
/// - `[v6addr]:port` → the text between the brackets
/// - `host:port` → everything before the last `:`
///
/// Returns `None` when no host can be extracted: no `:` separator, an unterminated `[`, or
/// an empty host (e.g. `:324`, `[]:324`). `finalize_transport` then reports a clear
/// configuration error at startup, rather than the empty name failing validation on every
/// connect attempt.
fn default_server_name_for_addr(addr: &str) -> Option<String> {
    let host = match addr.strip_prefix('[') {
        Some(rest) => rest.split_once(']').map(|(host, _rest)| host),
        None => addr.rsplit_once(':').map(|(host, _port)| host),
    }?;
    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
/// Parses a strictly positive `u64` for the command-line option `name`.
///
/// # Errors
/// `InvalidInput` if `value` is not a valid `u64` or is zero.
fn parse_u64_arg(value: &str, name: &str) -> io::Result<u64> {
    match value.parse::<u64>() {
        Ok(0) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("{} must be > 0", name),
        )),
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("invalid value for {}: {}", name, err),
        )),
    }
}