1680 lines
58 KiB
Rust
1680 lines
58 KiB
Rust
use std::env;
|
|
use std::future::pending;
|
|
use std::io;
|
|
use std::path::{Path, PathBuf};
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::{Context, Poll};
|
|
|
|
use russh::client;
|
|
use russh::keys::{
|
|
PrivateKeyWithHashAlg, check_known_hosts_path, load_public_key, load_secret_key,
|
|
};
|
|
use russh::{ChannelStream, client::Msg as SshClientMsg};
|
|
use rustls::{ClientConfig as RustlsClientConfig, RootCertStore};
|
|
use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
|
|
use tokio::io::{
|
|
self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, ReadBuf, WriteHalf,
|
|
};
|
|
use tokio::net::TcpStream;
|
|
use tokio::time::{Duration, Instant, timeout};
|
|
use tokio_rustls::TlsConnector;
|
|
|
|
mod pretty;
|
|
mod protocol;
|
|
mod wire;
|
|
|
|
use crate::pretty::{parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu};
|
|
use crate::protocol::{PduHeader, PduType, QueryMode};
|
|
use crate::wire::{read_pdu, send_reset_query, send_serial_query};
|
|
|
|
/// Timestamped stand-in for `std::println!` (writes to stdout).
///
/// A bare `println!()` emits only a newline with no timestamp. Any other
/// invocation formats its arguments and prefixes them with the local time,
/// rendered with the chrono pattern `%Y-%m-%dT%H:%M:%S%.3f%:z` inside square
/// brackets.
macro_rules! println {
    () => {
        ::std::println!()
    };
    ($($arg:tt)*) => {{
        let stamp = chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%.3f%:z");
        ::std::println!("[{}] {}", stamp, format_args!($($arg)*))
    }};
}
|
|
|
|
/// Timestamped stand-in for `std::eprintln!` (writes to stderr).
///
/// A bare `eprintln!()` emits only a newline with no timestamp. Any other
/// invocation formats its arguments and prefixes them with the local time,
/// rendered with the chrono pattern `%Y-%m-%dT%H:%M:%S%.3f%:z` inside square
/// brackets.
macro_rules! eprintln {
    () => {
        ::std::eprintln!()
    };
    ($($arg:tt)*) => {{
        let stamp = chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%.3f%:z");
        ::std::eprintln!("[{}] {}", stamp, format_args!($($arg)*))
    }};
}
|
|
|
|
/// Seconds to wait for a PDU before reporting a read timeout; `--timeout` overrides it.
const DEFAULT_READ_TIMEOUT_SECS: u64 = 30;
/// Auto-poll interval (seconds) used until a refresh hint is known; `--poll` overrides it.
const DEFAULT_POLL_INTERVAL_SECS: u64 = 600;
/// Fallback SSH subsystem name when `--ssh-subsystem` is not given.
const DEFAULT_SSH_SUBSYSTEM_NAME: &str = "rpki-rtr";
|
|
|
|
trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
|
|
impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
|
|
|
|
type DynStream = Box<dyn AsyncStream>;
|
|
type ClientWriter = WriteHalf<DynStream>;
|
|
|
|
/// How much PDU detail the client prints.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum OutputMode {
    /// Print every received PDU plus state diagnostics.
    Verbose,
    /// Hide payload PDUs (prefixes, router keys, ASPA) and print summaries instead.
    SummaryOnly,
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> io::Result<()> {
|
|
let config = Config::from_args()?;
|
|
|
|
println!("== RTR debug client ==");
|
|
println!("target : {}", config.addr);
|
|
println!("transport: {}", config.transport.describe());
|
|
println!("version : {}", config.version);
|
|
println!("timeout : {}s", config.read_timeout_secs);
|
|
println!(
|
|
"poll : {}s (default before EndOfData refresh is known)",
|
|
config.default_poll_secs
|
|
);
|
|
println!("keep-after-error: {}", config.keep_after_error);
|
|
println!("output : {}", config.output_mode.describe());
|
|
match &config.mode {
|
|
QueryMode::Reset => {
|
|
println!("mode : reset");
|
|
}
|
|
QueryMode::Serial { session_id, serial } => {
|
|
println!("mode : serial");
|
|
println!("session : {}", session_id);
|
|
println!("serial : {}", serial);
|
|
}
|
|
}
|
|
println!();
|
|
if config.output_mode == OutputMode::Verbose {
|
|
print_help();
|
|
}
|
|
|
|
let mut state = ClientState::new(
|
|
config.version,
|
|
config.read_timeout_secs,
|
|
config.default_poll_secs,
|
|
config.keep_after_error,
|
|
config.output_mode,
|
|
);
|
|
|
|
let stdin = tokio::io::stdin();
|
|
let mut stdin_lines = BufReader::new(stdin).lines();
|
|
let mut stdin_closed = false;
|
|
|
|
loop {
|
|
let stream = loop {
|
|
match connect_stream(&config).await {
|
|
Ok(stream) => {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("connected to {}", config.addr);
|
|
}
|
|
break stream;
|
|
}
|
|
Err(err) => {
|
|
let delay = state.reconnect_delay_secs();
|
|
eprintln!("connect failed: {}. retry after {}s", err, delay);
|
|
tokio::time::sleep(Duration::from_secs(delay)).await;
|
|
}
|
|
}
|
|
};
|
|
|
|
let (mut reader, mut writer) = tokio_io::split(stream);
|
|
send_resume_query(&mut writer, &mut state, &config.mode).await?;
|
|
state.schedule_next_poll();
|
|
println!();
|
|
|
|
let reconnect = loop {
|
|
let poll_sleep = tokio::time::sleep_until(state.poll_deadline());
|
|
tokio::pin!(poll_sleep);
|
|
|
|
tokio::select! {
|
|
line = async {
|
|
if stdin_closed {
|
|
pending::<io::Result<Option<String>>>().await
|
|
} else {
|
|
stdin_lines.next_line().await
|
|
}
|
|
} => {
|
|
match line {
|
|
Ok(Some(line)) => {
|
|
match handle_console_command(
|
|
&line,
|
|
Some(&mut writer),
|
|
&mut state,
|
|
).await {
|
|
Ok(should_quit) => {
|
|
if should_quit {
|
|
println!("quit requested, closing client.");
|
|
return Ok(());
|
|
}
|
|
}
|
|
Err(err) if should_reconnect(&err) => {
|
|
eprintln!("command failed due to disconnected transport: {}", err);
|
|
break true;
|
|
}
|
|
Err(err) => return Err(err),
|
|
}
|
|
}
|
|
Ok(None) => {
|
|
stdin_closed = true;
|
|
println!("stdin closed, disable console input.");
|
|
}
|
|
Err(err) => {
|
|
eprintln!("read stdin failed: {}", err);
|
|
}
|
|
}
|
|
}
|
|
|
|
_ = &mut poll_sleep => {
|
|
match handle_poll_tick(&mut writer, &mut state).await {
|
|
Ok(()) => state.schedule_next_poll(),
|
|
Err(err) if should_reconnect(&err) => {
|
|
eprintln!("auto poll failed due to disconnected transport: {}", err);
|
|
break true;
|
|
}
|
|
Err(err) => return Err(err),
|
|
}
|
|
}
|
|
|
|
read_result = timeout(
|
|
Duration::from_secs(state.read_timeout_secs),
|
|
read_pdu(&mut reader)
|
|
) => {
|
|
match read_result {
|
|
Ok(Ok(pdu)) => {
|
|
state.observe_pdu(&pdu.header);
|
|
if should_print_pdu(state.output_mode, &pdu.header) {
|
|
print_raw_pdu(&pdu.header, &pdu.body);
|
|
print_pdu(&pdu.header, &pdu.body);
|
|
}
|
|
match handle_incoming_pdu(&mut writer, &mut state, &pdu.header, &pdu.body).await {
|
|
Ok(()) => {}
|
|
Err(err) if should_reconnect(&err) => {
|
|
eprintln!("connection dropped while handling incoming PDU: {}", err);
|
|
break true;
|
|
}
|
|
Err(err) => return Err(err),
|
|
}
|
|
}
|
|
Ok(Err(err)) => {
|
|
eprintln!("read PDU failed: {}", err);
|
|
if should_reconnect(&err) {
|
|
break true;
|
|
}
|
|
return Err(err);
|
|
}
|
|
Err(_) => {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!(
|
|
"[timeout] no PDU received in {}s, connection kept open.",
|
|
state.read_timeout_secs
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
if reconnect {
|
|
let delay = state.reconnect_delay_secs();
|
|
state.current_session_id = None;
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("[reconnect] transport disconnected, retry after {}s", delay);
|
|
}
|
|
|
|
let reconnect_sleep = tokio::time::sleep(Duration::from_secs(delay));
|
|
tokio::pin!(reconnect_sleep);
|
|
let mut reconnect_now = false;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = &mut reconnect_sleep => break,
|
|
line = async {
|
|
if stdin_closed {
|
|
pending::<io::Result<Option<String>>>().await
|
|
} else {
|
|
stdin_lines.next_line().await
|
|
}
|
|
} => {
|
|
match line {
|
|
Ok(Some(line)) => {
|
|
match handle_console_command(&line, None, &mut state).await {
|
|
Ok(should_quit) => {
|
|
if should_quit {
|
|
println!("quit requested, closing client.");
|
|
return Ok(());
|
|
}
|
|
}
|
|
Ok(false) => {
|
|
if state.take_reconnect_now() {
|
|
reconnect_now = true;
|
|
break;
|
|
}
|
|
}
|
|
Err(err) => return Err(err),
|
|
}
|
|
}
|
|
Ok(None) => {
|
|
stdin_closed = true;
|
|
println!("stdin closed, disable console input.");
|
|
}
|
|
Err(err) => {
|
|
eprintln!("read stdin failed: {}", err);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if reconnect_now {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("[reconnect] user requested immediate reconnect");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn send_resume_query(
|
|
writer: &mut ClientWriter,
|
|
state: &mut ClientState,
|
|
mode: &QueryMode,
|
|
) -> io::Result<()> {
|
|
if state.force_reset_on_reconnect {
|
|
state.force_reset_on_reconnect = false;
|
|
state.session_id = None;
|
|
state.serial = None;
|
|
state.current_session_id = None;
|
|
send_reset_query(writer, state.version).await?;
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("reconnected, send Reset Query (forced)");
|
|
}
|
|
return Ok(());
|
|
}
|
|
|
|
match (state.session_id, state.serial) {
|
|
(Some(session_id), Some(serial)) => {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!(
|
|
"reconnected, send Serial Query with session_id={}, serial={}",
|
|
session_id, serial
|
|
);
|
|
}
|
|
send_serial_query(writer, state.version, session_id, serial).await?;
|
|
}
|
|
_ => match mode {
|
|
QueryMode::Reset => {
|
|
send_reset_query(writer, state.version).await?;
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("sent Reset Query");
|
|
}
|
|
}
|
|
QueryMode::Serial { session_id, serial } => {
|
|
state.session_id = Some(*session_id);
|
|
state.serial = Some(*serial);
|
|
send_serial_query(writer, state.version, *session_id, *serial).await?;
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("sent Serial Query");
|
|
}
|
|
}
|
|
},
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns `true` when `err` indicates the connection went away, so the
/// caller should reconnect rather than abort.
fn should_reconnect(err: &io::Error) -> bool {
    const TRANSPORT_GONE: [io::ErrorKind; 5] = [
        io::ErrorKind::UnexpectedEof,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::NotConnected,
    ];
    TRANSPORT_GONE.contains(&err.kind())
}
|
|
|
|
async fn handle_incoming_pdu(
|
|
writer: &mut ClientWriter,
|
|
state: &mut ClientState,
|
|
header: &PduHeader,
|
|
body: &[u8],
|
|
) -> io::Result<()> {
|
|
match header.pdu_type() {
|
|
PduType::CacheResponse => {
|
|
state.current_session_id = Some(header.session_id());
|
|
state.last_error_code = None;
|
|
}
|
|
|
|
PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa => {
|
|
if state.current_session_id.is_none() {
|
|
state.current_session_id = Some(header.session_id());
|
|
}
|
|
}
|
|
|
|
PduType::EndOfData => {
|
|
let session_id = header.session_id();
|
|
let eod = parse_end_of_data_info(body);
|
|
|
|
state.session_id = Some(session_id);
|
|
state.current_session_id = Some(session_id);
|
|
|
|
if let Some(eod) = eod {
|
|
state.serial = Some(eod.serial);
|
|
state.refresh = eod.refresh;
|
|
state.retry = eod.retry;
|
|
state.expire = eod.expire;
|
|
state.last_error_code = None;
|
|
|
|
state.schedule_next_poll();
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!();
|
|
println!(
|
|
"updated client state: session_id={}, serial={}",
|
|
session_id, eod.serial
|
|
);
|
|
if let Some(refresh) = eod.refresh {
|
|
println!("refresh : {}", refresh);
|
|
}
|
|
if let Some(retry) = eod.retry {
|
|
println!("retry : {}", retry);
|
|
}
|
|
if let Some(expire) = eod.expire {
|
|
println!("expire : {}", expire);
|
|
}
|
|
println!(
|
|
"next auto poll scheduled after {}s",
|
|
state.effective_poll_secs()
|
|
);
|
|
} else {
|
|
println!(
|
|
"EndOfData: session_id={}, serial={}, next_poll={}s",
|
|
session_id,
|
|
eod.serial,
|
|
state.effective_poll_secs()
|
|
);
|
|
}
|
|
} else {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!();
|
|
println!(
|
|
"updated client state: session_id={}, serial=<unknown>",
|
|
session_id
|
|
);
|
|
} else {
|
|
println!("EndOfData: session_id={}, serial=<unknown>", session_id);
|
|
}
|
|
}
|
|
|
|
if state.output_mode == OutputMode::SummaryOnly
|
|
&& state.skipped_payload_pdu_count_in_round > 0
|
|
{
|
|
println!(
|
|
"summary : skipped {} payload PDUs in this response",
|
|
state.skipped_payload_pdu_count_in_round
|
|
);
|
|
state.skipped_payload_pdu_count_in_round = 0;
|
|
}
|
|
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("received EndOfData, keep connection open.");
|
|
println!();
|
|
}
|
|
}
|
|
|
|
PduType::SerialNotify => {
|
|
let notify_session_id = header.session_id();
|
|
let notify_serial = parse_serial_notify_serial(body);
|
|
|
|
println!();
|
|
println!(
|
|
"[notify] received Serial Notify: session_id={}, notify_serial={:?}",
|
|
notify_session_id, notify_serial
|
|
);
|
|
|
|
match (state.session_id, state.serial, notify_serial) {
|
|
(Some(current_session_id), Some(current_serial), Some(_new_serial))
|
|
if current_session_id == notify_session_id =>
|
|
{
|
|
println!(
|
|
"received Serial Notify for current session {}, send Serial Query with serial {}",
|
|
current_session_id, current_serial
|
|
);
|
|
send_serial_query(writer, state.version, current_session_id, current_serial)
|
|
.await?;
|
|
}
|
|
|
|
_ => {
|
|
println!(
|
|
"received Serial Notify but local session/serial state is not usable, send Reset Query"
|
|
);
|
|
send_reset_query(writer, state.version).await?;
|
|
}
|
|
}
|
|
|
|
state.schedule_next_poll();
|
|
println!();
|
|
}
|
|
|
|
PduType::CacheReset => {
|
|
println!();
|
|
println!("received Cache Reset, send Reset Query");
|
|
state.current_session_id = None;
|
|
state.serial = None;
|
|
state.last_error_code = None;
|
|
state.skipped_payload_pdu_count_in_round = 0;
|
|
send_reset_query(writer, state.version).await?;
|
|
state.schedule_next_poll();
|
|
println!();
|
|
}
|
|
|
|
PduType::ErrorReport => {
|
|
println!();
|
|
println!("received Error Report, pause auto polling for debugging.");
|
|
state.last_error_code = Some(header.error_code());
|
|
if let Some(retry) = state.retry {
|
|
println!("server retry hint currently stored: {}s", retry);
|
|
if state.should_prefer_retry_poll() {
|
|
println!("when resumed, auto polling will use retry instead of refresh.");
|
|
}
|
|
}
|
|
if state.keep_after_error {
|
|
println!("keep-after-error is enabled, auto polling will continue.");
|
|
state.schedule_next_poll();
|
|
} else {
|
|
println!("use `reset`, `serial`, or `poll resume` to continue manually.");
|
|
state.pause_auto_poll();
|
|
}
|
|
println!();
|
|
}
|
|
|
|
PduType::SerialQuery | PduType::ResetQuery | PduType::Unknown(_) => {
|
|
// only print, no extra action
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_poll_tick(writer: &mut ClientWriter, state: &mut ClientState) -> io::Result<()> {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!();
|
|
println!(
|
|
"[auto-poll] timer fired (interval={}s)",
|
|
state.effective_poll_secs()
|
|
);
|
|
}
|
|
|
|
match (state.session_id, state.serial) {
|
|
(Some(session_id), Some(serial)) => {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!(
|
|
"[auto-poll] send Serial Query with session_id={}, serial={}",
|
|
session_id, serial
|
|
);
|
|
}
|
|
send_serial_query(writer, state.version, session_id, serial).await?;
|
|
}
|
|
_ => {
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!("[auto-poll] local state incomplete, send Reset Query");
|
|
}
|
|
send_reset_query(writer, state.version).await?;
|
|
}
|
|
}
|
|
|
|
if state.output_mode == OutputMode::Verbose {
|
|
println!();
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_console_command(
|
|
line: &str,
|
|
mut writer: Option<&mut ClientWriter>,
|
|
state: &mut ClientState,
|
|
) -> io::Result<bool> {
|
|
let line = line.trim();
|
|
|
|
if line.is_empty() {
|
|
return Ok(false);
|
|
}
|
|
|
|
let parts: Vec<&str> = line.split_whitespace().collect();
|
|
|
|
match parts.as_slice() {
|
|
["help"] => {
|
|
print_help();
|
|
}
|
|
|
|
["state"] => {
|
|
print_state(state);
|
|
}
|
|
|
|
["version"] => {
|
|
println!("current RTR version: {}", state.version);
|
|
}
|
|
|
|
["version", version] => {
|
|
let version = match version.parse::<u8>() {
|
|
Ok(v) => v,
|
|
Err(err) => {
|
|
println!("invalid version: {}", err);
|
|
return Ok(false);
|
|
}
|
|
};
|
|
state.version = version;
|
|
println!("updated RTR version to {}", state.version);
|
|
}
|
|
|
|
["reset"] => {
|
|
println!("manual command: send Reset Query");
|
|
if let Some(writer) = writer.as_mut() {
|
|
send_reset_query(writer, state.version).await?;
|
|
state.schedule_next_poll();
|
|
} else {
|
|
state.force_reset_on_reconnect = true;
|
|
state.request_reconnect_now();
|
|
state.session_id = None;
|
|
state.serial = None;
|
|
state.current_session_id = None;
|
|
println!("not connected, queued Reset Query for next reconnect");
|
|
}
|
|
}
|
|
|
|
["serial"] => match (state.session_id, state.serial) {
|
|
(Some(session_id), Some(serial)) => {
|
|
println!(
|
|
"manual command: send Serial Query with current state: session_id={}, serial={}",
|
|
session_id, serial
|
|
);
|
|
if let Some(writer) = writer.as_mut() {
|
|
send_serial_query(writer, state.version, session_id, serial).await?;
|
|
state.schedule_next_poll();
|
|
} else {
|
|
println!("not connected, will send Serial Query on reconnect");
|
|
}
|
|
}
|
|
_ => {
|
|
println!(
|
|
"manual command failed: current session_id/serial not available, use `reset` or `serial <session_id> <serial>`"
|
|
);
|
|
}
|
|
},
|
|
|
|
["serial", session_id, serial] => {
|
|
let session_id = match session_id.parse::<u16>() {
|
|
Ok(v) => v,
|
|
Err(err) => {
|
|
println!("invalid session_id: {}", err);
|
|
return Ok(false);
|
|
}
|
|
};
|
|
|
|
let serial = match serial.parse::<u32>() {
|
|
Ok(v) => v,
|
|
Err(err) => {
|
|
println!("invalid serial: {}", err);
|
|
return Ok(false);
|
|
}
|
|
};
|
|
|
|
println!(
|
|
"manual command: send Serial Query with explicit args: session_id={}, serial={}",
|
|
session_id, serial
|
|
);
|
|
state.session_id = Some(session_id);
|
|
state.serial = Some(serial);
|
|
if let Some(writer) = writer.as_mut() {
|
|
send_serial_query(writer, state.version, session_id, serial).await?;
|
|
state.schedule_next_poll();
|
|
} else {
|
|
state.force_reset_on_reconnect = false;
|
|
println!("not connected, queued Serial Query for next reconnect");
|
|
}
|
|
}
|
|
|
|
["timeout"] => {
|
|
println!("current read timeout: {}s", state.read_timeout_secs);
|
|
}
|
|
|
|
["timeout", secs] => {
|
|
let secs = match secs.parse::<u64>() {
|
|
Ok(v) if v > 0 => v,
|
|
Ok(_) => {
|
|
println!("timeout must be > 0");
|
|
return Ok(false);
|
|
}
|
|
Err(err) => {
|
|
println!("invalid timeout seconds: {}", err);
|
|
return Ok(false);
|
|
}
|
|
};
|
|
|
|
state.read_timeout_secs = secs;
|
|
println!("updated read timeout to {}s", state.read_timeout_secs);
|
|
}
|
|
|
|
["poll"] => {
|
|
println!(
|
|
"current effective poll interval: {}s",
|
|
state.effective_poll_secs()
|
|
);
|
|
println!(
|
|
"poll interval source : {}",
|
|
state.poll_interval_source()
|
|
);
|
|
println!("stored refresh hint : {:?}", state.refresh);
|
|
println!("default poll interval : {}s", state.default_poll_secs);
|
|
println!("last_error_code : {:?}", state.last_error_code);
|
|
println!("auto polling paused : {}", state.poll_paused);
|
|
}
|
|
|
|
["poll", "pause"] => {
|
|
state.pause_auto_poll();
|
|
println!("auto polling paused");
|
|
}
|
|
|
|
["poll", "resume"] => {
|
|
state.resume_auto_poll();
|
|
println!(
|
|
"auto polling resumed, next poll scheduled after {}s",
|
|
state.effective_poll_secs()
|
|
);
|
|
}
|
|
|
|
["poll", secs] => {
|
|
let secs = match secs.parse::<u64>() {
|
|
Ok(v) if v > 0 => v,
|
|
Ok(_) => {
|
|
println!("poll interval must be > 0");
|
|
return Ok(false);
|
|
}
|
|
Err(err) => {
|
|
println!("invalid poll seconds: {}", err);
|
|
return Ok(false);
|
|
}
|
|
};
|
|
|
|
state.refresh = Some(secs as u32);
|
|
state.schedule_next_poll();
|
|
println!("updated poll interval to {}s", secs);
|
|
}
|
|
|
|
["quit"] | ["exit"] => {
|
|
return Ok(true);
|
|
}
|
|
|
|
["output"] => {
|
|
println!("current output mode: {}", state.output_mode.describe());
|
|
println!("skipped payload PDUs: {}", state.skipped_payload_pdu_count);
|
|
}
|
|
|
|
["output", "verbose"] => {
|
|
state.output_mode = OutputMode::Verbose;
|
|
println!("updated output mode to verbose");
|
|
}
|
|
|
|
["output", "summary"] => {
|
|
state.output_mode = OutputMode::SummaryOnly;
|
|
println!("updated output mode to summary");
|
|
}
|
|
|
|
_ => {
|
|
println!("unknown command: {}", line);
|
|
print_help();
|
|
}
|
|
}
|
|
|
|
Ok(false)
|
|
}
|
|
|
|
/// Prints the interactive console command reference, one command per line.
fn print_help() {
    const COMMAND_LINES: &[&str] = &[
        " help show this help",
        " state print current client state",
        " version show current RTR version",
        " version <n> update RTR version",
        " reset send Reset Query",
        " serial send Serial Query with current session_id/serial",
        " serial <sid> <serial> send Serial Query with explicit values",
        " timeout show current read timeout",
        " timeout <secs> update read timeout seconds",
        " poll show current poll interval",
        " poll <secs> override poll interval seconds",
        " poll pause pause auto polling",
        " poll resume resume auto polling",
        " keep-after-error show current keep-after-error setting",
        " output show current output mode",
        " output verbose print all PDUs",
        " output summary suppress payload PDU details",
        " quit exit client",
    ];

    println!("available commands:");
    for entry in COMMAND_LINES {
        println!("{}", entry);
    }
    println!();
}
|
|
|
|
fn print_state(state: &ClientState) {
|
|
println!("client state:");
|
|
println!(" version : {}", state.version);
|
|
println!(" session_id : {:?}", state.session_id);
|
|
println!(" serial : {:?}", state.serial);
|
|
println!(" current_session_id : {:?}", state.current_session_id);
|
|
println!(" refresh : {:?}", state.refresh);
|
|
println!(" retry : {:?}", state.retry);
|
|
println!(" expire : {:?}", state.expire);
|
|
println!(" read_timeout_secs : {}", state.read_timeout_secs);
|
|
println!(" default_poll_secs : {}", state.default_poll_secs);
|
|
println!(" effective_poll_secs: {}", state.effective_poll_secs());
|
|
println!(" poll_source : {}", state.poll_interval_source());
|
|
println!(" last_error_code : {:?}", state.last_error_code);
|
|
println!(" keep_after_error : {}", state.keep_after_error);
|
|
println!(" output_mode : {}", state.output_mode.describe());
|
|
println!(" skipped_payloads : {}", state.skipped_payload_pdu_count);
|
|
println!(" poll_paused : {}", state.poll_paused);
|
|
println!();
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ClientState {
|
|
version: u8,
|
|
session_id: Option<u16>,
|
|
serial: Option<u32>,
|
|
current_session_id: Option<u16>,
|
|
|
|
refresh: Option<u32>,
|
|
retry: Option<u32>,
|
|
expire: Option<u32>,
|
|
last_error_code: Option<u16>,
|
|
keep_after_error: bool,
|
|
output_mode: OutputMode,
|
|
skipped_payload_pdu_count: u64,
|
|
skipped_payload_pdu_count_in_round: u64,
|
|
|
|
read_timeout_secs: u64,
|
|
default_poll_secs: u64,
|
|
next_poll_deadline: Instant,
|
|
poll_paused: bool,
|
|
force_reset_on_reconnect: bool,
|
|
reconnect_now: bool,
|
|
}
|
|
|
|
impl ClientState {
|
|
fn new(
|
|
version: u8,
|
|
read_timeout_secs: u64,
|
|
default_poll_secs: u64,
|
|
keep_after_error: bool,
|
|
output_mode: OutputMode,
|
|
) -> Self {
|
|
Self {
|
|
version,
|
|
session_id: None,
|
|
serial: None,
|
|
current_session_id: None,
|
|
refresh: None,
|
|
retry: None,
|
|
expire: None,
|
|
last_error_code: None,
|
|
keep_after_error,
|
|
output_mode,
|
|
skipped_payload_pdu_count: 0,
|
|
skipped_payload_pdu_count_in_round: 0,
|
|
read_timeout_secs,
|
|
default_poll_secs,
|
|
next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs),
|
|
poll_paused: false,
|
|
force_reset_on_reconnect: false,
|
|
reconnect_now: false,
|
|
}
|
|
}
|
|
|
|
fn effective_poll_secs(&self) -> u64 {
|
|
if self.should_prefer_retry_poll() {
|
|
self.retry.map(|v| v as u64).unwrap_or_else(|| {
|
|
self.refresh
|
|
.map(|v| v as u64)
|
|
.unwrap_or(self.default_poll_secs)
|
|
})
|
|
} else {
|
|
self.refresh
|
|
.map(|v| v as u64)
|
|
.unwrap_or(self.default_poll_secs)
|
|
}
|
|
}
|
|
|
|
fn schedule_next_poll(&mut self) {
|
|
self.next_poll_deadline = Instant::now() + Duration::from_secs(self.effective_poll_secs());
|
|
}
|
|
|
|
fn pause_auto_poll(&mut self) {
|
|
self.poll_paused = true;
|
|
}
|
|
|
|
fn resume_auto_poll(&mut self) {
|
|
self.poll_paused = false;
|
|
self.schedule_next_poll();
|
|
}
|
|
|
|
fn poll_deadline(&self) -> Instant {
|
|
if self.poll_paused {
|
|
Instant::now() + Duration::from_secs(365 * 24 * 60 * 60)
|
|
} else {
|
|
self.next_poll_deadline
|
|
}
|
|
}
|
|
|
|
fn should_prefer_retry_poll(&self) -> bool {
|
|
matches!(self.last_error_code, Some(2 | 10))
|
|
}
|
|
|
|
fn poll_interval_source(&self) -> &'static str {
|
|
if self.should_prefer_retry_poll() && self.retry.is_some() {
|
|
"retry"
|
|
} else if self.refresh.is_some() {
|
|
"refresh"
|
|
} else {
|
|
"default"
|
|
}
|
|
}
|
|
|
|
fn reconnect_delay_secs(&self) -> u64 {
|
|
if self.should_prefer_retry_poll() {
|
|
self.retry
|
|
.map(|v| v as u64)
|
|
.unwrap_or(self.default_poll_secs)
|
|
} else {
|
|
self.default_poll_secs
|
|
}
|
|
}
|
|
|
|
fn request_reconnect_now(&mut self) {
|
|
self.reconnect_now = true;
|
|
}
|
|
|
|
fn take_reconnect_now(&mut self) -> bool {
|
|
if self.reconnect_now {
|
|
self.reconnect_now = false;
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
fn observe_pdu(&mut self, header: &PduHeader) {
|
|
if self.output_mode == OutputMode::SummaryOnly && is_payload_pdu(header) {
|
|
self.skipped_payload_pdu_count = self.skipped_payload_pdu_count.saturating_add(1);
|
|
self.skipped_payload_pdu_count_in_round =
|
|
self.skipped_payload_pdu_count_in_round.saturating_add(1);
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Config {
|
|
addr: String,
|
|
version: u8,
|
|
mode: QueryMode,
|
|
read_timeout_secs: u64,
|
|
default_poll_secs: u64,
|
|
transport: TransportConfig,
|
|
keep_after_error: bool,
|
|
output_mode: OutputMode,
|
|
}
|
|
|
|
impl Config {
|
|
fn from_args() -> io::Result<Self> {
|
|
let mut args = env::args().skip(1);
|
|
let mut positional = Vec::new();
|
|
let mut transport = TransportConfig::Tcp;
|
|
let mut read_timeout_secs = DEFAULT_READ_TIMEOUT_SECS;
|
|
let mut default_poll_secs = DEFAULT_POLL_INTERVAL_SECS;
|
|
let mut keep_after_error = false;
|
|
let mut output_mode = OutputMode::Verbose;
|
|
|
|
while let Some(arg) = args.next() {
|
|
match arg.as_str() {
|
|
"--tls" => match transport {
|
|
TransportConfig::Tcp => {
|
|
transport = TransportConfig::Tls(TlsConfig::default());
|
|
}
|
|
TransportConfig::Tls(_) => {}
|
|
TransportConfig::Ssh(_) => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--tls cannot be used together with --ssh",
|
|
));
|
|
}
|
|
},
|
|
"--ssh" => match transport {
|
|
TransportConfig::Tcp => {
|
|
transport = TransportConfig::Ssh(SshConfig::default());
|
|
}
|
|
TransportConfig::Ssh(_) => {}
|
|
TransportConfig::Tls(_) => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh cannot be used together with --tls",
|
|
));
|
|
}
|
|
},
|
|
"--ca-cert" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--ca-cert requires a path")
|
|
})?;
|
|
ensure_tls_config(&mut transport)?.ca_cert = Some(PathBuf::from(path));
|
|
}
|
|
"--client-cert" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--client-cert requires a path")
|
|
})?;
|
|
ensure_tls_config(&mut transport)?.client_cert = Some(PathBuf::from(path));
|
|
}
|
|
"--client-key" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--client-key requires a path")
|
|
})?;
|
|
ensure_tls_config(&mut transport)?.client_key = Some(PathBuf::from(path));
|
|
}
|
|
"--server-name" => {
|
|
let name = args.next().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--server-name requires a value",
|
|
)
|
|
})?;
|
|
ensure_tls_config(&mut transport)?.server_name = Some(name);
|
|
}
|
|
"--ssh-user" => {
|
|
let user = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--ssh-user requires a value")
|
|
})?;
|
|
ensure_ssh_config(&mut transport)?.user = Some(user);
|
|
}
|
|
"--ssh-key" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--ssh-key requires a path")
|
|
})?;
|
|
ensure_ssh_config(&mut transport)?.private_key = Some(PathBuf::from(path));
|
|
}
|
|
"--ssh-subsystem" => {
|
|
let subsystem = args.next().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh-subsystem requires a value",
|
|
)
|
|
})?;
|
|
ensure_ssh_config(&mut transport)?.subsystem = Some(subsystem);
|
|
}
|
|
"--ssh-known-hosts" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh-known-hosts requires a path",
|
|
)
|
|
})?;
|
|
ensure_ssh_config(&mut transport)?.known_hosts = Some(PathBuf::from(path));
|
|
}
|
|
"--ssh-server-key" => {
|
|
let path = args.next().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh-server-key requires a path",
|
|
)
|
|
})?;
|
|
ensure_ssh_config(&mut transport)?.server_key = Some(PathBuf::from(path));
|
|
}
|
|
"--timeout" => {
|
|
let secs = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--timeout requires seconds")
|
|
})?;
|
|
read_timeout_secs = parse_u64_arg(&secs, "--timeout")?;
|
|
}
|
|
"--poll" => {
|
|
let secs = args.next().ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "--poll requires seconds")
|
|
})?;
|
|
default_poll_secs = parse_u64_arg(&secs, "--poll")?;
|
|
}
|
|
"--keep-after-error" => {
|
|
keep_after_error = true;
|
|
}
|
|
"--summary-only" => {
|
|
output_mode = OutputMode::SummaryOnly;
|
|
}
|
|
_ if arg.starts_with("--") => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("unknown option '{}'", arg),
|
|
));
|
|
}
|
|
_ => positional.push(arg),
|
|
}
|
|
}
|
|
|
|
let mut positional = positional.into_iter();
|
|
|
|
let addr = positional
|
|
.next()
|
|
.unwrap_or_else(|| "127.0.0.1:323".to_string());
|
|
|
|
let version = positional
|
|
.next()
|
|
.map(|s| {
|
|
s.parse::<u8>().map_err(|e| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid version '{}': {}", s, e),
|
|
)
|
|
})
|
|
})
|
|
.transpose()?
|
|
.unwrap_or(1);
|
|
|
|
// Allow any version here; server will validate and respond.
|
|
|
|
let mode = match positional.next().as_deref() {
|
|
None | Some("reset") => QueryMode::Reset,
|
|
Some("serial") => {
|
|
let session_id = positional
|
|
.next()
|
|
.ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"serial mode requires session_id and serial",
|
|
)
|
|
})?
|
|
.parse::<u16>()
|
|
.map_err(|e| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid session_id: {}", e),
|
|
)
|
|
})?;
|
|
|
|
let serial = positional
|
|
.next()
|
|
.ok_or_else(|| {
|
|
io::Error::new(io::ErrorKind::InvalidInput, "serial mode requires serial")
|
|
})?
|
|
.parse::<u32>()
|
|
.map_err(|e| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid serial: {}", e),
|
|
)
|
|
})?;
|
|
|
|
QueryMode::Serial { session_id, serial }
|
|
}
|
|
Some(other) => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid mode '{}', expected 'reset' or 'serial'", other),
|
|
));
|
|
}
|
|
};
|
|
|
|
let transport = finalize_transport(transport, &addr)?;
|
|
|
|
Ok(Self {
|
|
addr,
|
|
version,
|
|
mode,
|
|
read_timeout_secs,
|
|
default_poll_secs,
|
|
transport,
|
|
keep_after_error,
|
|
output_mode,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl OutputMode {
|
|
fn describe(self) -> &'static str {
|
|
match self {
|
|
Self::Verbose => "verbose",
|
|
Self::SummaryOnly => "summary",
|
|
}
|
|
}
|
|
}
|
|
|
|
fn is_payload_pdu(header: &PduHeader) -> bool {
|
|
matches!(
|
|
header.pdu_type(),
|
|
PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa
|
|
)
|
|
}
|
|
|
|
fn should_print_pdu(output_mode: OutputMode, header: &PduHeader) -> bool {
|
|
match output_mode {
|
|
OutputMode::Verbose => true,
|
|
OutputMode::SummaryOnly => !is_payload_pdu(header),
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum TransportConfig {
|
|
Tcp,
|
|
Tls(TlsConfig),
|
|
Ssh(SshConfig),
|
|
}
|
|
|
|
impl TransportConfig {
|
|
fn describe(&self) -> String {
|
|
match self {
|
|
Self::Tcp => "tcp".to_string(),
|
|
Self::Tls(cfg) => format!(
|
|
"tls (server_name={}, ca_cert={}, client_cert={})",
|
|
cfg.server_name.as_deref().unwrap_or("<unset>"),
|
|
cfg.ca_cert
|
|
.as_ref()
|
|
.map(|path| path.display().to_string())
|
|
.unwrap_or_else(|| "<unset>".to_string()),
|
|
cfg.client_cert
|
|
.as_ref()
|
|
.map(|path| path.display().to_string())
|
|
.unwrap_or_else(|| "<none>".to_string())
|
|
),
|
|
Self::Ssh(cfg) => format!(
|
|
"ssh (user={}, subsystem={}, host_key_check={})",
|
|
cfg.user.as_deref().unwrap_or("<unset>"),
|
|
cfg.subsystem
|
|
.as_deref()
|
|
.unwrap_or(DEFAULT_SSH_SUBSYSTEM_NAME),
|
|
cfg.host_key_verification
|
|
.as_ref()
|
|
.map(HostKeyVerification::describe)
|
|
.unwrap_or("<unset>")
|
|
),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// TLS transport settings collected from the command line.
///
/// In the config returned by `finalize_transport`, `server_name` and `ca_cert`
/// are always `Some`. `client_cert` and `client_key` are either both set or
/// both unset.
#[derive(Debug, Clone, Default)]
struct TlsConfig {
    /// Server name passed to the TLS handshake. `finalize_transport` defaults it
    /// to the host part of the address when not given explicitly.
    server_name: Option<String>,
    /// PEM file whose certificates form the (only) trusted root store.
    ca_cert: Option<PathBuf>,
    /// Optional PEM certificate chain presented for TLS client authentication.
    client_cert: Option<PathBuf>,
    /// PEM private key belonging to `client_cert`.
    client_key: Option<PathBuf>,
}
|
|
|
|
/// How the SSH server's host key is checked during the handshake.
#[derive(Debug, Clone)]
enum HostKeyVerification {
    /// Look the server's `host:port` and key up in this known_hosts file.
    KnownHosts(PathBuf),
    /// Accept only the public key stored in this file.
    PinnedServerKey(PathBuf),
}
|
|
|
|
impl HostKeyVerification {
|
|
fn describe(&self) -> &'static str {
|
|
match self {
|
|
Self::KnownHosts(_) => "known_hosts",
|
|
Self::PinnedServerKey(_) => "pinned_server_key",
|
|
}
|
|
}
|
|
}
|
|
|
|
/// SSH transport settings collected from the command line.
///
/// `finalize_transport` moves `known_hosts` / `server_key` into
/// `host_key_verification`, leaving both `None`. It also requires `user`,
/// `private_key` and `subsystem` to be `Some` in the config it returns.
#[derive(Debug, Clone, Default)]
struct SshConfig {
    /// Login name used for public-key authentication.
    user: Option<String>,
    /// Path to the client's private key file.
    private_key: Option<PathBuf>,
    /// SSH subsystem to request; `DEFAULT_SSH_SUBSYSTEM_NAME` when left unset.
    subsystem: Option<String>,
    /// Path given via `--ssh-known-hosts` (mutually exclusive with `server_key`).
    known_hosts: Option<PathBuf>,
    /// Path given via `--ssh-server-key` (mutually exclusive with `known_hosts`).
    server_key: Option<PathBuf>,
    /// Resolved host key check, filled in by `finalize_transport`.
    host_key_verification: Option<HostKeyVerification>,
}
|
|
|
|
fn ensure_tls_config(transport: &mut TransportConfig) -> io::Result<&mut TlsConfig> {
|
|
match transport {
|
|
TransportConfig::Tcp => {
|
|
*transport = TransportConfig::Tls(TlsConfig::default());
|
|
match transport {
|
|
TransportConfig::Tls(cfg) => Ok(cfg),
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
TransportConfig::Tls(cfg) => Ok(cfg),
|
|
TransportConfig::Ssh(_) => Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"TLS options cannot be used together with --ssh",
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn ensure_ssh_config(transport: &mut TransportConfig) -> io::Result<&mut SshConfig> {
|
|
match transport {
|
|
TransportConfig::Tcp => {
|
|
*transport = TransportConfig::Ssh(SshConfig::default());
|
|
match transport {
|
|
TransportConfig::Ssh(cfg) => Ok(cfg),
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
TransportConfig::Ssh(cfg) => Ok(cfg),
|
|
TransportConfig::Tls(_) => Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"SSH options cannot be used together with --tls",
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn finalize_transport(transport: TransportConfig, addr: &str) -> io::Result<TransportConfig> {
|
|
match transport {
|
|
TransportConfig::Tcp => Ok(TransportConfig::Tcp),
|
|
TransportConfig::Tls(mut cfg) => {
|
|
let ca_cert = cfg.ca_cert.take().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"TLS mode requires --ca-cert <path>",
|
|
)
|
|
})?;
|
|
|
|
match (&cfg.client_cert, &cfg.client_key) {
|
|
(Some(_), Some(_)) | (None, None) => {}
|
|
_ => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"TLS client authentication requires both --client-cert and --client-key",
|
|
));
|
|
}
|
|
}
|
|
|
|
let server_name = cfg
|
|
.server_name
|
|
.take()
|
|
.or_else(|| default_server_name_for_addr(addr))
|
|
.ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"TLS mode requires --server-name or an address with a parsable host",
|
|
)
|
|
})?;
|
|
|
|
Ok(TransportConfig::Tls(TlsConfig {
|
|
server_name: Some(server_name),
|
|
ca_cert: Some(ca_cert),
|
|
client_cert: cfg.client_cert,
|
|
client_key: cfg.client_key,
|
|
}))
|
|
}
|
|
TransportConfig::Ssh(mut cfg) => {
|
|
let user = cfg.user.take().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"SSH mode requires --ssh-user <name>",
|
|
)
|
|
})?;
|
|
if user.trim().is_empty() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh-user must not be empty",
|
|
));
|
|
}
|
|
|
|
let private_key = cfg.private_key.take().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"SSH mode requires --ssh-key <path>",
|
|
)
|
|
})?;
|
|
|
|
if cfg.known_hosts.is_some() && cfg.server_key.is_some() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"SSH host key verification must choose one: --ssh-known-hosts or --ssh-server-key",
|
|
));
|
|
}
|
|
|
|
let host_key_verification = if let Some(path) = cfg.known_hosts.take() {
|
|
HostKeyVerification::KnownHosts(path)
|
|
} else if let Some(path) = cfg.server_key.take() {
|
|
HostKeyVerification::PinnedServerKey(path)
|
|
} else {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"SSH mode requires host key verification: --ssh-known-hosts <path> or --ssh-server-key <path>",
|
|
));
|
|
};
|
|
|
|
let subsystem = cfg
|
|
.subsystem
|
|
.take()
|
|
.unwrap_or_else(|| DEFAULT_SSH_SUBSYSTEM_NAME.to_string());
|
|
|
|
if subsystem.trim().is_empty() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"--ssh-subsystem must not be empty",
|
|
));
|
|
}
|
|
|
|
let _ = parse_host_port(addr)?;
|
|
|
|
Ok(TransportConfig::Ssh(SshConfig {
|
|
user: Some(user),
|
|
private_key: Some(private_key),
|
|
subsystem: Some(subsystem),
|
|
known_hosts: None,
|
|
server_key: None,
|
|
host_key_verification: Some(host_key_verification),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn connect_stream(config: &Config) -> io::Result<DynStream> {
|
|
match &config.transport {
|
|
TransportConfig::Tcp => Ok(Box::new(TcpStream::connect(&config.addr).await?)),
|
|
TransportConfig::Tls(tls) => connect_tls_stream(&config.addr, tls).await,
|
|
TransportConfig::Ssh(ssh) => connect_ssh_stream(&config.addr, ssh).await,
|
|
}
|
|
}
|
|
|
|
/// russh client callbacks for an outgoing RTR-over-SSH connection.
///
/// Only server host key checking is customized (see its `client::Handler` impl).
#[derive(Debug, Clone)]
struct SshClientHandler {
    /// Host part of the server address; used for the known_hosts lookup.
    host: String,
    /// Server port; used for the known_hosts lookup.
    port: u16,
    /// How the server's host key is checked.
    host_key_verification: HostKeyVerification,
}
|
|
|
|
impl client::Handler for SshClientHandler {
|
|
type Error = russh::Error;
|
|
|
|
async fn check_server_key(
|
|
&mut self,
|
|
server_public_key: &russh::keys::ssh_key::PublicKey,
|
|
) -> Result<bool, Self::Error> {
|
|
match &self.host_key_verification {
|
|
HostKeyVerification::KnownHosts(path) => {
|
|
check_known_hosts_path(&self.host, self.port, server_public_key, path)
|
|
.map_err(Into::into)
|
|
}
|
|
HostKeyVerification::PinnedServerKey(path) => {
|
|
let expected_key = load_public_key(path)?;
|
|
Ok(expected_key == *server_public_key)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Byte stream over an SSH subsystem channel; boxed as a `DynStream` by
/// `connect_ssh_stream`.
///
/// All I/O goes through `channel_stream`. `_session` is never read; it is
/// stored so the session handle lives exactly as long as the stream.
/// NOTE(review): presumably dropping the `client::Handle` would tear down the
/// SSH session under the channel — confirm against the russh docs.
struct SshSessionStream {
    /// Read/write side of the SSH channel carrying the RTR bytes.
    channel_stream: ChannelStream<SshClientMsg>,
    /// SSH session handle, held only to keep it alive; intentionally unused.
    _session: client::Handle<SshClientHandler>,
}
|
|
|
|
impl AsyncRead for SshSessionStream {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut self.channel_stream).poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for SshSessionStream {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
Pin::new(&mut self.channel_stream).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut self.channel_stream).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
Pin::new(&mut self.channel_stream).poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
async fn connect_ssh_stream(addr: &str, ssh: &SshConfig) -> io::Result<DynStream> {
|
|
let user = ssh
|
|
.user
|
|
.as_deref()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH user"))?;
|
|
let private_key_path = ssh
|
|
.private_key
|
|
.as_ref()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH private key"))?;
|
|
let subsystem = ssh
|
|
.subsystem
|
|
.as_deref()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH subsystem"))?;
|
|
let host_key_verification = ssh.host_key_verification.clone().ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
"missing SSH host key verification",
|
|
)
|
|
})?;
|
|
|
|
let (host, port) = parse_host_port(addr)?;
|
|
let handler = SshClientHandler {
|
|
host,
|
|
port,
|
|
host_key_verification,
|
|
};
|
|
|
|
let session_config = Arc::new(client::Config::default());
|
|
let mut session = client::connect(session_config, addr, handler)
|
|
.await
|
|
.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::ConnectionAborted,
|
|
format!("SSH handshake failed: {}", err),
|
|
)
|
|
})?;
|
|
|
|
let private_key = load_secret_key(private_key_path, None).map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!(
|
|
"failed to load SSH private key {}: {}",
|
|
private_key_path.display(),
|
|
err
|
|
),
|
|
)
|
|
})?;
|
|
let rsa_hash = session.best_supported_rsa_hash().await.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::ConnectionAborted,
|
|
format!("failed to negotiate SSH RSA hash: {}", err),
|
|
)
|
|
})?;
|
|
let auth_result = session
|
|
.authenticate_publickey(
|
|
user.to_string(),
|
|
PrivateKeyWithHashAlg::new(Arc::new(private_key), rsa_hash.flatten()),
|
|
)
|
|
.await
|
|
.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::PermissionDenied,
|
|
format!("SSH publickey authentication failed: {}", err),
|
|
)
|
|
})?;
|
|
if !auth_result.success() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::PermissionDenied,
|
|
"SSH publickey authentication rejected by server",
|
|
));
|
|
}
|
|
|
|
let channel = session.channel_open_session().await.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::ConnectionAborted,
|
|
format!("failed to open SSH session channel: {}", err),
|
|
)
|
|
})?;
|
|
channel
|
|
.request_subsystem(true, subsystem)
|
|
.await
|
|
.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::ConnectionAborted,
|
|
format!("failed to request SSH subsystem '{}': {}", subsystem, err),
|
|
)
|
|
})?;
|
|
let channel_stream = channel.into_stream();
|
|
|
|
Ok(Box::new(SshSessionStream {
|
|
channel_stream,
|
|
_session: session,
|
|
}))
|
|
}
|
|
|
|
async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result<DynStream> {
|
|
let stream = TcpStream::connect(addr).await?;
|
|
let connector = build_tls_connector(tls)?;
|
|
let server_name_str = tls
|
|
.server_name
|
|
.as_ref()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS server name"))?;
|
|
let server_name = ServerName::try_from(server_name_str.clone()).map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid TLS server name '{}': {}", server_name_str, err),
|
|
)
|
|
})?;
|
|
let tls_stream = connector
|
|
.connect(server_name, stream)
|
|
.await
|
|
.map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::ConnectionAborted,
|
|
format!("TLS handshake failed: {}", err),
|
|
)
|
|
})?;
|
|
Ok(Box::new(tls_stream))
|
|
}
|
|
|
|
fn build_tls_connector(tls: &TlsConfig) -> io::Result<TlsConnector> {
|
|
let ca_cert_path = tls
|
|
.ca_cert
|
|
.as_ref()
|
|
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS CA cert"))?;
|
|
let ca_certs = load_certs(ca_cert_path)?;
|
|
let mut roots = RootCertStore::empty();
|
|
let (added, _ignored) = roots.add_parsable_certificates(ca_certs);
|
|
if added == 0 {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!(
|
|
"no valid CA certificates found in {}",
|
|
ca_cert_path.display()
|
|
),
|
|
));
|
|
}
|
|
|
|
let builder = RustlsClientConfig::builder().with_root_certificates(roots);
|
|
let client_config = match (&tls.client_cert, &tls.client_key) {
|
|
(Some(cert_path), Some(key_path)) => {
|
|
let certs = load_certs(cert_path)?;
|
|
let key = load_private_key(key_path)?;
|
|
builder.with_client_auth_cert(certs, key).map_err(|err| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidInput,
|
|
format!("invalid TLS client certificate/key: {}", err),
|
|
)
|
|
})?
|
|
}
|
|
(None, None) => builder.with_no_client_auth(),
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
Ok(TlsConnector::from(Arc::new(client_config)))
|
|
}
|
|
|
|
fn load_certs(path: &Path) -> io::Result<Vec<CertificateDer<'static>>> {
|
|
let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
|
|
let certs = rustls_pemfile::certs(&mut reader)
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
|
|
if certs.is_empty() {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::InvalidData,
|
|
format!("no certificates found in {}", path.display()),
|
|
));
|
|
}
|
|
Ok(certs)
|
|
}
|
|
|
|
fn load_private_key(path: &Path) -> io::Result<PrivateKeyDer<'static>> {
|
|
let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
|
|
rustls_pemfile::private_key(&mut reader)
|
|
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
|
|
.ok_or_else(|| {
|
|
io::Error::new(
|
|
io::ErrorKind::InvalidData,
|
|
format!("no private key found in {}", path.display()),
|
|
)
|
|
})
|
|
}
|
|
|
|
/// Derives a default TLS server name from the host part of `addr`.
///
/// * `"[2001:db8::1]:323"` → `"2001:db8::1"`: the text after `[` up to the
///   first `]` (or the end of the string).
/// * `"rtr.example.net:323"` → `"rtr.example.net"`: the text before the last `:`.
///
/// Returns `None` in two cases:
/// * `addr` has neither a leading `[` nor any `:`.
/// * The extracted host is empty (e.g. `":323"`, `"[]:323"`). An empty string
///   is never a usable server name.
fn default_server_name_for_addr(addr: &str) -> Option<String> {
    let host = match addr.strip_prefix('[') {
        Some(rest) => rest.split(']').next(),
        None => addr.rsplit_once(':').map(|(host, _port)| host),
    };
    host.filter(|host| !host.is_empty()).map(str::to_string)
}
|
|
|
|
/// Splits `addr` into `(host, port)`.
///
/// Accepts `host:port`, split at the last `:`, and `[host]:port` for IPv6
/// literals. In both forms the host must be non-empty and the port must fit
/// in a `u16`.
///
/// # Errors
/// `InvalidInput` if the separator is missing, the host is empty, or the port
/// is not a valid `u16`.
fn parse_host_port(addr: &str) -> io::Result<(String, u16)> {
    let (host, port_part) = if let Some(rest) = addr.strip_prefix('[') {
        rest.split_once("]:").ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid address '{}', expected [host]:port", addr),
            )
        })?
    } else {
        addr.rsplit_once(':').ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid address '{}', expected host:port", addr),
            )
        })?
    };

    // Checked for both forms: previously "[]:323" slipped through with an empty host.
    if host.is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("invalid address '{}', host must not be empty", addr),
        ));
    }

    let port = port_part.parse::<u16>().map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("invalid port in address '{}': {}", addr, err),
        )
    })?;
    Ok((host.to_string(), port))
}
|
|
|
|
/// Parses a strictly positive integer command-line value.
///
/// `name` is the option name, used only in error messages.
///
/// # Errors
/// `InvalidInput` if `value` is not a valid `u64` or is zero.
fn parse_u64_arg(value: &str, name: &str) -> io::Result<u64> {
    match value.parse::<u64>() {
        Err(err) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("invalid value for {}: {}", name, err),
        )),
        Ok(0) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("{} must be > 0", name),
        )),
        Ok(parsed) => Ok(parsed),
    }
}
|