use std::env; use std::future::pending; use std::io; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use russh::client; use russh::keys::{ PrivateKeyWithHashAlg, check_known_hosts_path, load_public_key, load_secret_key, }; use russh::{ChannelStream, client::Msg as SshClientMsg}; use rustls::{ClientConfig as RustlsClientConfig, RootCertStore}; use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use tokio::io::{ self as tokio_io, AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader, ReadBuf, WriteHalf, }; use tokio::net::TcpStream; use tokio::time::{Duration, Instant, timeout}; use tokio_rustls::TlsConnector; mod pretty; mod protocol; mod wire; use crate::pretty::{parse_end_of_data_info, parse_serial_notify_serial, print_pdu, print_raw_pdu}; use crate::protocol::{PduHeader, PduType, QueryMode}; use crate::wire::{read_pdu, send_reset_query, send_serial_query}; macro_rules! println { () => { ::std::println!(); }; ($($arg:tt)*) => {{ let ts = chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%.3f%:z"); ::std::println!("[{}] {}", ts, format_args!($($arg)*)); }}; } macro_rules! eprintln { () => { ::std::eprintln!(); }; ($($arg:tt)*) => {{ let ts = chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%.3f%:z"); ::std::eprintln!("[{}] {}", ts, format_args!($($arg)*)); }}; } const DEFAULT_READ_TIMEOUT_SECS: u64 = 30; const DEFAULT_POLL_INTERVAL_SECS: u64 = 600; const DEFAULT_SSH_SUBSYSTEM_NAME: &str = "rpki-rtr"; trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {} impl AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {} type DynStream = Box; type ClientWriter = WriteHalf; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum OutputMode { Verbose, SummaryOnly, } #[tokio::main] async fn main() -> io::Result<()> { let config = Config::from_args()?; println!("== RTR debug client =="); println!("target : {}", config.addr); println!("transport: {}", config.transport.describe()); println!("version : {}", config.version); println!("timeout : {}s", config.read_timeout_secs); println!( "poll : {}s (default before EndOfData refresh is known)", config.default_poll_secs ); println!("keep-after-error: {}", config.keep_after_error); println!("output : {}", config.output_mode.describe()); match &config.mode { QueryMode::Reset => { println!("mode : reset"); } QueryMode::Serial { session_id, serial } => { println!("mode : serial"); println!("session : {}", session_id); println!("serial : {}", serial); } } println!(); if config.output_mode == OutputMode::Verbose { print_help(); } let mut state = ClientState::new( config.version, config.read_timeout_secs, config.default_poll_secs, config.keep_after_error, config.output_mode, ); let stdin = tokio::io::stdin(); let mut stdin_lines = BufReader::new(stdin).lines(); let mut stdin_closed = false; loop { let stream = loop { match connect_stream(&config).await { Ok(stream) => { if state.output_mode == OutputMode::Verbose { println!("connected to {}", config.addr); } break stream; } Err(err) => { let delay = state.reconnect_delay_secs(); eprintln!("connect failed: {}. retry after {}s", err, delay); tokio::time::sleep(Duration::from_secs(delay)).await; } } }; let (mut reader, mut writer) = tokio_io::split(stream); send_resume_query(&mut writer, &mut state, &config.mode).await?; state.schedule_next_poll(); println!(); let reconnect = loop { let poll_sleep = tokio::time::sleep_until(state.poll_deadline()); tokio::pin!(poll_sleep); tokio::select! { line = async { if stdin_closed { pending::>>().await } else { stdin_lines.next_line().await } } => { match line { Ok(Some(line)) => { match handle_console_command( &line, Some(&mut writer), &mut state, ).await { Ok(should_quit) => { if should_quit { println!("quit requested, closing client."); return Ok(()); } } Err(err) if should_reconnect(&err) => { eprintln!("command failed due to disconnected transport: {}", err); break true; } Err(err) => return Err(err), } } Ok(None) => { stdin_closed = true; println!("stdin closed, disable console input."); } Err(err) => { eprintln!("read stdin failed: {}", err); } } } _ = &mut poll_sleep => { match handle_poll_tick(&mut writer, &mut state).await { Ok(()) => state.schedule_next_poll(), Err(err) if should_reconnect(&err) => { eprintln!("auto poll failed due to disconnected transport: {}", err); break true; } Err(err) => return Err(err), } } read_result = timeout( Duration::from_secs(state.read_timeout_secs), read_pdu(&mut reader) ) => { match read_result { Ok(Ok(pdu)) => { state.observe_pdu(&pdu.header); if should_print_pdu(state.output_mode, &pdu.header) { print_raw_pdu(&pdu.header, &pdu.body); print_pdu(&pdu.header, &pdu.body); } match handle_incoming_pdu(&mut writer, &mut state, &pdu.header, &pdu.body).await { Ok(()) => {} Err(err) if should_reconnect(&err) => { eprintln!("connection dropped while handling incoming PDU: {}", err); break true; } Err(err) => return Err(err), } } Ok(Err(err)) => { eprintln!("read PDU failed: {}", err); if should_reconnect(&err) { break true; } return Err(err); } Err(_) => { if state.output_mode == OutputMode::Verbose { println!( "[timeout] no PDU received in {}s, connection kept open.", state.read_timeout_secs ); } } } } } }; if reconnect { let delay = state.reconnect_delay_secs(); state.current_session_id = None; if state.output_mode == OutputMode::Verbose { println!("[reconnect] transport disconnected, retry after {}s", delay); } let reconnect_sleep = tokio::time::sleep(Duration::from_secs(delay)); tokio::pin!(reconnect_sleep); let mut reconnect_now = false; loop { tokio::select! { _ = &mut reconnect_sleep => break, line = async { if stdin_closed { pending::>>().await } else { stdin_lines.next_line().await } } => { match line { Ok(Some(line)) => { match handle_console_command(&line, None, &mut state).await { Ok(should_quit) => { if should_quit { println!("quit requested, closing client."); return Ok(()); } } Ok(false) => { if state.take_reconnect_now() { reconnect_now = true; break; } } Err(err) => return Err(err), } } Ok(None) => { stdin_closed = true; println!("stdin closed, disable console input."); } Err(err) => { eprintln!("read stdin failed: {}", err); } } } } } if reconnect_now { if state.output_mode == OutputMode::Verbose { println!("[reconnect] user requested immediate reconnect"); } } } } } async fn send_resume_query( writer: &mut ClientWriter, state: &mut ClientState, mode: &QueryMode, ) -> io::Result<()> { if state.force_reset_on_reconnect { state.force_reset_on_reconnect = false; state.session_id = None; state.serial = None; state.current_session_id = None; send_reset_query(writer, state.version).await?; if state.output_mode == OutputMode::Verbose { println!("reconnected, send Reset Query (forced)"); } return Ok(()); } match (state.session_id, state.serial) { (Some(session_id), Some(serial)) => { if state.output_mode == OutputMode::Verbose { println!( "reconnected, send Serial Query with session_id={}, serial={}", session_id, serial ); } send_serial_query(writer, state.version, session_id, serial).await?; } _ => match mode { QueryMode::Reset => { send_reset_query(writer, state.version).await?; if state.output_mode == OutputMode::Verbose { println!("sent Reset Query"); } } QueryMode::Serial { session_id, serial } => { state.session_id = Some(*session_id); state.serial = Some(*serial); send_serial_query(writer, state.version, *session_id, *serial).await?; if state.output_mode == OutputMode::Verbose { println!("sent Serial Query"); } } }, } Ok(()) } fn should_reconnect(err: &io::Error) -> bool { matches!( err.kind(), io::ErrorKind::UnexpectedEof | io::ErrorKind::ConnectionAborted | io::ErrorKind::ConnectionReset | io::ErrorKind::BrokenPipe | io::ErrorKind::NotConnected ) } async fn handle_incoming_pdu( writer: &mut ClientWriter, state: &mut ClientState, header: &PduHeader, body: &[u8], ) -> io::Result<()> { match header.pdu_type() { PduType::CacheResponse => { state.current_session_id = Some(header.session_id()); state.last_error_code = None; } PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa => { if state.current_session_id.is_none() { state.current_session_id = Some(header.session_id()); } } PduType::EndOfData => { let session_id = header.session_id(); let eod = parse_end_of_data_info(body); state.session_id = Some(session_id); state.current_session_id = Some(session_id); if let Some(eod) = eod { state.serial = Some(eod.serial); state.refresh = eod.refresh; state.retry = eod.retry; state.expire = eod.expire; state.last_error_code = None; state.schedule_next_poll(); if state.output_mode == OutputMode::Verbose { println!(); println!( "updated client state: session_id={}, serial={}", session_id, eod.serial ); if let Some(refresh) = eod.refresh { println!("refresh : {}", refresh); } if let Some(retry) = eod.retry { println!("retry : {}", retry); } if let Some(expire) = eod.expire { println!("expire : {}", expire); } println!( "next auto poll scheduled after {}s", state.effective_poll_secs() ); } else { println!( "EndOfData: session_id={}, serial={}, next_poll={}s", session_id, eod.serial, state.effective_poll_secs() ); } } else { if state.output_mode == OutputMode::Verbose { println!(); println!( "updated client state: session_id={}, serial=", session_id ); } else { println!("EndOfData: session_id={}, serial=", session_id); } } if state.output_mode == OutputMode::SummaryOnly && state.skipped_payload_pdu_count_in_round > 0 { println!( "summary : skipped {} payload PDUs in this response", state.skipped_payload_pdu_count_in_round ); state.skipped_payload_pdu_count_in_round = 0; } if state.output_mode == OutputMode::Verbose { println!("received EndOfData, keep connection open."); println!(); } } PduType::SerialNotify => { let notify_session_id = header.session_id(); let notify_serial = parse_serial_notify_serial(body); println!(); println!( "[notify] received Serial Notify: session_id={}, notify_serial={:?}", notify_session_id, notify_serial ); match (state.session_id, state.serial, notify_serial) { (Some(current_session_id), Some(current_serial), Some(_new_serial)) if current_session_id == notify_session_id => { println!( "received Serial Notify for current session {}, send Serial Query with serial {}", current_session_id, current_serial ); send_serial_query(writer, state.version, current_session_id, current_serial) .await?; } _ => { println!( "received Serial Notify but local session/serial state is not usable, send Reset Query" ); send_reset_query(writer, state.version).await?; } } state.schedule_next_poll(); println!(); } PduType::CacheReset => { println!(); println!("received Cache Reset, send Reset Query"); state.current_session_id = None; state.serial = None; state.last_error_code = None; state.skipped_payload_pdu_count_in_round = 0; send_reset_query(writer, state.version).await?; state.schedule_next_poll(); println!(); } PduType::ErrorReport => { println!(); println!("received Error Report, pause auto polling for debugging."); state.last_error_code = Some(header.error_code()); if let Some(retry) = state.retry { println!("server retry hint currently stored: {}s", retry); if state.should_prefer_retry_poll() { println!("when resumed, auto polling will use retry instead of refresh."); } } if state.keep_after_error { println!("keep-after-error is enabled, auto polling will continue."); state.schedule_next_poll(); } else { println!("use `reset`, `serial`, or `poll resume` to continue manually."); state.pause_auto_poll(); } println!(); } PduType::SerialQuery | PduType::ResetQuery | PduType::Unknown(_) => { // only print, no extra action } } Ok(()) } async fn handle_poll_tick(writer: &mut ClientWriter, state: &mut ClientState) -> io::Result<()> { if state.output_mode == OutputMode::Verbose { println!(); println!( "[auto-poll] timer fired (interval={}s)", state.effective_poll_secs() ); } match (state.session_id, state.serial) { (Some(session_id), Some(serial)) => { if state.output_mode == OutputMode::Verbose { println!( "[auto-poll] send Serial Query with session_id={}, serial={}", session_id, serial ); } send_serial_query(writer, state.version, session_id, serial).await?; } _ => { if state.output_mode == OutputMode::Verbose { println!("[auto-poll] local state incomplete, send Reset Query"); } send_reset_query(writer, state.version).await?; } } if state.output_mode == OutputMode::Verbose { println!(); } Ok(()) } async fn handle_console_command( line: &str, mut writer: Option<&mut ClientWriter>, state: &mut ClientState, ) -> io::Result { let line = line.trim(); if line.is_empty() { return Ok(false); } let parts: Vec<&str> = line.split_whitespace().collect(); match parts.as_slice() { ["help"] => { print_help(); } ["state"] => { print_state(state); } ["version"] => { println!("current RTR version: {}", state.version); } ["version", version] => { let version = match version.parse::() { Ok(v) => v, Err(err) => { println!("invalid version: {}", err); return Ok(false); } }; state.version = version; println!("updated RTR version to {}", state.version); } ["reset"] => { println!("manual command: send Reset Query"); if let Some(writer) = writer.as_mut() { send_reset_query(writer, state.version).await?; state.schedule_next_poll(); } else { state.force_reset_on_reconnect = true; state.request_reconnect_now(); state.session_id = None; state.serial = None; state.current_session_id = None; println!("not connected, queued Reset Query for next reconnect"); } } ["serial"] => match (state.session_id, state.serial) { (Some(session_id), Some(serial)) => { println!( "manual command: send Serial Query with current state: session_id={}, serial={}", session_id, serial ); if let Some(writer) = writer.as_mut() { send_serial_query(writer, state.version, session_id, serial).await?; state.schedule_next_poll(); } else { println!("not connected, will send Serial Query on reconnect"); } } _ => { println!( "manual command failed: current session_id/serial not available, use `reset` or `serial `" ); } }, ["serial", session_id, serial] => { let session_id = match session_id.parse::() { Ok(v) => v, Err(err) => { println!("invalid session_id: {}", err); return Ok(false); } }; let serial = match serial.parse::() { Ok(v) => v, Err(err) => { println!("invalid serial: {}", err); return Ok(false); } }; println!( "manual command: send Serial Query with explicit args: session_id={}, serial={}", session_id, serial ); state.session_id = Some(session_id); state.serial = Some(serial); if let Some(writer) = writer.as_mut() { send_serial_query(writer, state.version, session_id, serial).await?; state.schedule_next_poll(); } else { state.force_reset_on_reconnect = false; println!("not connected, queued Serial Query for next reconnect"); } } ["timeout"] => { println!("current read timeout: {}s", state.read_timeout_secs); } ["timeout", secs] => { let secs = match secs.parse::() { Ok(v) if v > 0 => v, Ok(_) => { println!("timeout must be > 0"); return Ok(false); } Err(err) => { println!("invalid timeout seconds: {}", err); return Ok(false); } }; state.read_timeout_secs = secs; println!("updated read timeout to {}s", state.read_timeout_secs); } ["poll"] => { println!( "current effective poll interval: {}s", state.effective_poll_secs() ); println!( "poll interval source : {}", state.poll_interval_source() ); println!("stored refresh hint : {:?}", state.refresh); println!("default poll interval : {}s", state.default_poll_secs); println!("last_error_code : {:?}", state.last_error_code); println!("auto polling paused : {}", state.poll_paused); } ["poll", "pause"] => { state.pause_auto_poll(); println!("auto polling paused"); } ["poll", "resume"] => { state.resume_auto_poll(); println!( "auto polling resumed, next poll scheduled after {}s", state.effective_poll_secs() ); } ["poll", secs] => { let secs = match secs.parse::() { Ok(v) if v > 0 => v, Ok(_) => { println!("poll interval must be > 0"); return Ok(false); } Err(err) => { println!("invalid poll seconds: {}", err); return Ok(false); } }; state.refresh = Some(secs as u32); state.schedule_next_poll(); println!("updated poll interval to {}s", secs); } ["quit"] | ["exit"] => { return Ok(true); } ["output"] => { println!("current output mode: {}", state.output_mode.describe()); println!("skipped payload PDUs: {}", state.skipped_payload_pdu_count); } ["output", "verbose"] => { state.output_mode = OutputMode::Verbose; println!("updated output mode to verbose"); } ["output", "summary"] => { state.output_mode = OutputMode::SummaryOnly; println!("updated output mode to summary"); } _ => { println!("unknown command: {}", line); print_help(); } } Ok(false) } fn print_help() { println!("available commands:"); println!(" help show this help"); println!(" state print current client state"); println!(" version show current RTR version"); println!(" version update RTR version"); println!(" reset send Reset Query"); println!(" serial send Serial Query with current session_id/serial"); println!(" serial send Serial Query with explicit values"); println!(" timeout show current read timeout"); println!(" timeout update read timeout seconds"); println!(" poll show current poll interval"); println!(" poll override poll interval seconds"); println!(" poll pause pause auto polling"); println!(" poll resume resume auto polling"); println!(" keep-after-error show current keep-after-error setting"); println!(" output show current output mode"); println!(" output verbose print all PDUs"); println!(" output summary suppress payload PDU details"); println!(" quit exit client"); println!(); } fn print_state(state: &ClientState) { println!("client state:"); println!(" version : {}", state.version); println!(" session_id : {:?}", state.session_id); println!(" serial : {:?}", state.serial); println!(" current_session_id : {:?}", state.current_session_id); println!(" refresh : {:?}", state.refresh); println!(" retry : {:?}", state.retry); println!(" expire : {:?}", state.expire); println!(" read_timeout_secs : {}", state.read_timeout_secs); println!(" default_poll_secs : {}", state.default_poll_secs); println!(" effective_poll_secs: {}", state.effective_poll_secs()); println!(" poll_source : {}", state.poll_interval_source()); println!(" last_error_code : {:?}", state.last_error_code); println!(" keep_after_error : {}", state.keep_after_error); println!(" output_mode : {}", state.output_mode.describe()); println!(" skipped_payloads : {}", state.skipped_payload_pdu_count); println!(" poll_paused : {}", state.poll_paused); println!(); } #[derive(Debug)] struct ClientState { version: u8, session_id: Option, serial: Option, current_session_id: Option, refresh: Option, retry: Option, expire: Option, last_error_code: Option, keep_after_error: bool, output_mode: OutputMode, skipped_payload_pdu_count: u64, skipped_payload_pdu_count_in_round: u64, read_timeout_secs: u64, default_poll_secs: u64, next_poll_deadline: Instant, poll_paused: bool, force_reset_on_reconnect: bool, reconnect_now: bool, } impl ClientState { fn new( version: u8, read_timeout_secs: u64, default_poll_secs: u64, keep_after_error: bool, output_mode: OutputMode, ) -> Self { Self { version, session_id: None, serial: None, current_session_id: None, refresh: None, retry: None, expire: None, last_error_code: None, keep_after_error, output_mode, skipped_payload_pdu_count: 0, skipped_payload_pdu_count_in_round: 0, read_timeout_secs, default_poll_secs, next_poll_deadline: Instant::now() + Duration::from_secs(default_poll_secs), poll_paused: false, force_reset_on_reconnect: false, reconnect_now: false, } } fn effective_poll_secs(&self) -> u64 { if self.should_prefer_retry_poll() { self.retry.map(|v| v as u64).unwrap_or_else(|| { self.refresh .map(|v| v as u64) .unwrap_or(self.default_poll_secs) }) } else { self.refresh .map(|v| v as u64) .unwrap_or(self.default_poll_secs) } } fn schedule_next_poll(&mut self) { self.next_poll_deadline = Instant::now() + Duration::from_secs(self.effective_poll_secs()); } fn pause_auto_poll(&mut self) { self.poll_paused = true; } fn resume_auto_poll(&mut self) { self.poll_paused = false; self.schedule_next_poll(); } fn poll_deadline(&self) -> Instant { if self.poll_paused { Instant::now() + Duration::from_secs(365 * 24 * 60 * 60) } else { self.next_poll_deadline } } fn should_prefer_retry_poll(&self) -> bool { matches!(self.last_error_code, Some(2 | 10)) } fn poll_interval_source(&self) -> &'static str { if self.should_prefer_retry_poll() && self.retry.is_some() { "retry" } else if self.refresh.is_some() { "refresh" } else { "default" } } fn reconnect_delay_secs(&self) -> u64 { if self.should_prefer_retry_poll() { self.retry .map(|v| v as u64) .unwrap_or(self.default_poll_secs) } else { self.default_poll_secs } } fn request_reconnect_now(&mut self) { self.reconnect_now = true; } fn take_reconnect_now(&mut self) -> bool { if self.reconnect_now { self.reconnect_now = false; true } else { false } } fn observe_pdu(&mut self, header: &PduHeader) { if self.output_mode == OutputMode::SummaryOnly && is_payload_pdu(header) { self.skipped_payload_pdu_count = self.skipped_payload_pdu_count.saturating_add(1); self.skipped_payload_pdu_count_in_round = self.skipped_payload_pdu_count_in_round.saturating_add(1); } } } #[derive(Debug)] struct Config { addr: String, version: u8, mode: QueryMode, read_timeout_secs: u64, default_poll_secs: u64, transport: TransportConfig, keep_after_error: bool, output_mode: OutputMode, } impl Config { fn from_args() -> io::Result { let mut args = env::args().skip(1); let mut positional = Vec::new(); let mut transport = TransportConfig::Tcp; let mut read_timeout_secs = DEFAULT_READ_TIMEOUT_SECS; let mut default_poll_secs = DEFAULT_POLL_INTERVAL_SECS; let mut keep_after_error = false; let mut output_mode = OutputMode::Verbose; while let Some(arg) = args.next() { match arg.as_str() { "--tls" => match transport { TransportConfig::Tcp => { transport = TransportConfig::Tls(TlsConfig::default()); } TransportConfig::Tls(_) => {} TransportConfig::Ssh(_) => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "--tls cannot be used together with --ssh", )); } }, "--ssh" => match transport { TransportConfig::Tcp => { transport = TransportConfig::Ssh(SshConfig::default()); } TransportConfig::Ssh(_) => {} TransportConfig::Tls(_) => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "--ssh cannot be used together with --tls", )); } }, "--ca-cert" => { let path = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--ca-cert requires a path") })?; ensure_tls_config(&mut transport)?.ca_cert = Some(PathBuf::from(path)); } "--client-cert" => { let path = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--client-cert requires a path") })?; ensure_tls_config(&mut transport)?.client_cert = Some(PathBuf::from(path)); } "--client-key" => { let path = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--client-key requires a path") })?; ensure_tls_config(&mut transport)?.client_key = Some(PathBuf::from(path)); } "--server-name" => { let name = args.next().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "--server-name requires a value", ) })?; ensure_tls_config(&mut transport)?.server_name = Some(name); } "--ssh-user" => { let user = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--ssh-user requires a value") })?; ensure_ssh_config(&mut transport)?.user = Some(user); } "--ssh-key" => { let path = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--ssh-key requires a path") })?; ensure_ssh_config(&mut transport)?.private_key = Some(PathBuf::from(path)); } "--ssh-subsystem" => { let subsystem = args.next().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "--ssh-subsystem requires a value", ) })?; ensure_ssh_config(&mut transport)?.subsystem = Some(subsystem); } "--ssh-known-hosts" => { let path = args.next().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "--ssh-known-hosts requires a path", ) })?; ensure_ssh_config(&mut transport)?.known_hosts = Some(PathBuf::from(path)); } "--ssh-server-key" => { let path = args.next().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "--ssh-server-key requires a path", ) })?; ensure_ssh_config(&mut transport)?.server_key = Some(PathBuf::from(path)); } "--timeout" => { let secs = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--timeout requires seconds") })?; read_timeout_secs = parse_u64_arg(&secs, "--timeout")?; } "--poll" => { let secs = args.next().ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "--poll requires seconds") })?; default_poll_secs = parse_u64_arg(&secs, "--poll")?; } "--keep-after-error" => { keep_after_error = true; } "--summary-only" => { output_mode = OutputMode::SummaryOnly; } _ if arg.starts_with("--") => { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("unknown option '{}'", arg), )); } _ => positional.push(arg), } } let mut positional = positional.into_iter(); let addr = positional .next() .unwrap_or_else(|| "127.0.0.1:323".to_string()); let version = positional .next() .map(|s| { s.parse::().map_err(|e| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid version '{}': {}", s, e), ) }) }) .transpose()? .unwrap_or(1); // Allow any version here; server will validate and respond. let mode = match positional.next().as_deref() { None | Some("reset") => QueryMode::Reset, Some("serial") => { let session_id = positional .next() .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "serial mode requires session_id and serial", ) })? .parse::() .map_err(|e| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid session_id: {}", e), ) })?; let serial = positional .next() .ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "serial mode requires serial") })? .parse::() .map_err(|e| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid serial: {}", e), ) })?; QueryMode::Serial { session_id, serial } } Some(other) => { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("invalid mode '{}', expected 'reset' or 'serial'", other), )); } }; let transport = finalize_transport(transport, &addr)?; Ok(Self { addr, version, mode, read_timeout_secs, default_poll_secs, transport, keep_after_error, output_mode, }) } } impl OutputMode { fn describe(self) -> &'static str { match self { Self::Verbose => "verbose", Self::SummaryOnly => "summary", } } } fn is_payload_pdu(header: &PduHeader) -> bool { matches!( header.pdu_type(), PduType::Ipv4Prefix | PduType::Ipv6Prefix | PduType::RouterKey | PduType::Aspa ) } fn should_print_pdu(output_mode: OutputMode, header: &PduHeader) -> bool { match output_mode { OutputMode::Verbose => true, OutputMode::SummaryOnly => !is_payload_pdu(header), } } #[derive(Debug, Clone)] enum TransportConfig { Tcp, Tls(TlsConfig), Ssh(SshConfig), } impl TransportConfig { fn describe(&self) -> String { match self { Self::Tcp => "tcp".to_string(), Self::Tls(cfg) => format!( "tls (server_name={}, ca_cert={}, client_cert={})", cfg.server_name.as_deref().unwrap_or(""), cfg.ca_cert .as_ref() .map(|path| path.display().to_string()) .unwrap_or_else(|| "".to_string()), cfg.client_cert .as_ref() .map(|path| path.display().to_string()) .unwrap_or_else(|| "".to_string()) ), Self::Ssh(cfg) => format!( "ssh (user={}, subsystem={}, host_key_check={})", cfg.user.as_deref().unwrap_or(""), cfg.subsystem .as_deref() .unwrap_or(DEFAULT_SSH_SUBSYSTEM_NAME), cfg.host_key_verification .as_ref() .map(HostKeyVerification::describe) .unwrap_or("") ), } } } #[derive(Debug, Clone, Default)] struct TlsConfig { server_name: Option, ca_cert: Option, client_cert: Option, client_key: Option, } #[derive(Debug, Clone)] enum HostKeyVerification { KnownHosts(PathBuf), PinnedServerKey(PathBuf), } impl HostKeyVerification { fn describe(&self) -> &'static str { match self { Self::KnownHosts(_) => "known_hosts", Self::PinnedServerKey(_) => "pinned_server_key", } } } #[derive(Debug, Clone, Default)] struct SshConfig { user: Option, private_key: Option, subsystem: Option, known_hosts: Option, server_key: Option, host_key_verification: Option, } fn ensure_tls_config(transport: &mut TransportConfig) -> io::Result<&mut TlsConfig> { match transport { TransportConfig::Tcp => { *transport = TransportConfig::Tls(TlsConfig::default()); match transport { TransportConfig::Tls(cfg) => Ok(cfg), _ => unreachable!(), } } TransportConfig::Tls(cfg) => Ok(cfg), TransportConfig::Ssh(_) => Err(io::Error::new( io::ErrorKind::InvalidInput, "TLS options cannot be used together with --ssh", )), } } fn ensure_ssh_config(transport: &mut TransportConfig) -> io::Result<&mut SshConfig> { match transport { TransportConfig::Tcp => { *transport = TransportConfig::Ssh(SshConfig::default()); match transport { TransportConfig::Ssh(cfg) => Ok(cfg), _ => unreachable!(), } } TransportConfig::Ssh(cfg) => Ok(cfg), TransportConfig::Tls(_) => Err(io::Error::new( io::ErrorKind::InvalidInput, "SSH options cannot be used together with --tls", )), } } fn finalize_transport(transport: TransportConfig, addr: &str) -> io::Result { match transport { TransportConfig::Tcp => Ok(TransportConfig::Tcp), TransportConfig::Tls(mut cfg) => { let ca_cert = cfg.ca_cert.take().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "TLS mode requires --ca-cert ", ) })?; match (&cfg.client_cert, &cfg.client_key) { (Some(_), Some(_)) | (None, None) => {} _ => { return Err(io::Error::new( io::ErrorKind::InvalidInput, "TLS client authentication requires both --client-cert and --client-key", )); } } let server_name = cfg .server_name .take() .or_else(|| default_server_name_for_addr(addr)) .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "TLS mode requires --server-name or an address with a parsable host", ) })?; Ok(TransportConfig::Tls(TlsConfig { server_name: Some(server_name), ca_cert: Some(ca_cert), client_cert: cfg.client_cert, client_key: cfg.client_key, })) } TransportConfig::Ssh(mut cfg) => { let user = cfg.user.take().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "SSH mode requires --ssh-user ", ) })?; if user.trim().is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "--ssh-user must not be empty", )); } let private_key = cfg.private_key.take().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "SSH mode requires --ssh-key ", ) })?; if cfg.known_hosts.is_some() && cfg.server_key.is_some() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "SSH host key verification must choose one: --ssh-known-hosts or --ssh-server-key", )); } let host_key_verification = if let Some(path) = cfg.known_hosts.take() { HostKeyVerification::KnownHosts(path) } else if let Some(path) = cfg.server_key.take() { HostKeyVerification::PinnedServerKey(path) } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, "SSH mode requires host key verification: --ssh-known-hosts or --ssh-server-key ", )); }; let subsystem = cfg .subsystem .take() .unwrap_or_else(|| DEFAULT_SSH_SUBSYSTEM_NAME.to_string()); if subsystem.trim().is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "--ssh-subsystem must not be empty", )); } let _ = parse_host_port(addr)?; Ok(TransportConfig::Ssh(SshConfig { user: Some(user), private_key: Some(private_key), subsystem: Some(subsystem), known_hosts: None, server_key: None, host_key_verification: Some(host_key_verification), })) } } } async fn connect_stream(config: &Config) -> io::Result { match &config.transport { TransportConfig::Tcp => Ok(Box::new(TcpStream::connect(&config.addr).await?)), TransportConfig::Tls(tls) => connect_tls_stream(&config.addr, tls).await, TransportConfig::Ssh(ssh) => connect_ssh_stream(&config.addr, ssh).await, } } #[derive(Debug, Clone)] struct SshClientHandler { host: String, port: u16, host_key_verification: HostKeyVerification, } impl client::Handler for SshClientHandler { type Error = russh::Error; async fn check_server_key( &mut self, server_public_key: &russh::keys::ssh_key::PublicKey, ) -> Result { match &self.host_key_verification { HostKeyVerification::KnownHosts(path) => { check_known_hosts_path(&self.host, self.port, server_public_key, path) .map_err(Into::into) } HostKeyVerification::PinnedServerKey(path) => { let expected_key = load_public_key(path)?; Ok(expected_key == *server_public_key) } } } } struct SshSessionStream { channel_stream: ChannelStream, _session: client::Handle, } impl AsyncRead for SshSessionStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.channel_stream).poll_read(cx, buf) } } impl AsyncWrite for SshSessionStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.channel_stream).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.channel_stream).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.channel_stream).poll_shutdown(cx) } } async fn connect_ssh_stream(addr: &str, ssh: &SshConfig) -> io::Result { let user = ssh .user .as_deref() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH user"))?; let private_key_path = ssh .private_key .as_ref() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH private key"))?; let subsystem = ssh .subsystem .as_deref() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing SSH subsystem"))?; let host_key_verification = ssh.host_key_verification.clone().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "missing SSH host key verification", ) })?; let (host, port) = parse_host_port(addr)?; let handler = SshClientHandler { host, port, host_key_verification, }; let session_config = Arc::new(client::Config::default()); let mut session = client::connect(session_config, addr, handler) .await .map_err(|err| { io::Error::new( io::ErrorKind::ConnectionAborted, format!("SSH handshake failed: {}", err), ) })?; let private_key = load_secret_key(private_key_path, None).map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!( "failed to load SSH private key {}: {}", private_key_path.display(), err ), ) })?; let rsa_hash = session.best_supported_rsa_hash().await.map_err(|err| { io::Error::new( io::ErrorKind::ConnectionAborted, format!("failed to negotiate SSH RSA hash: {}", err), ) })?; let auth_result = session .authenticate_publickey( user.to_string(), PrivateKeyWithHashAlg::new(Arc::new(private_key), rsa_hash.flatten()), ) .await .map_err(|err| { io::Error::new( io::ErrorKind::PermissionDenied, format!("SSH publickey authentication failed: {}", err), ) })?; if !auth_result.success() { return Err(io::Error::new( io::ErrorKind::PermissionDenied, "SSH publickey authentication rejected by server", )); } let channel = session.channel_open_session().await.map_err(|err| { io::Error::new( io::ErrorKind::ConnectionAborted, format!("failed to open SSH session channel: {}", err), ) })?; channel .request_subsystem(true, subsystem) .await .map_err(|err| { io::Error::new( io::ErrorKind::ConnectionAborted, format!("failed to request SSH subsystem '{}': {}", subsystem, err), ) })?; let channel_stream = channel.into_stream(); Ok(Box::new(SshSessionStream { channel_stream, _session: session, })) } async fn connect_tls_stream(addr: &str, tls: &TlsConfig) -> io::Result { let stream = TcpStream::connect(addr).await?; let connector = build_tls_connector(tls)?; let server_name_str = tls .server_name .as_ref() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS server name"))?; let server_name = ServerName::try_from(server_name_str.clone()).map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid TLS server name '{}': {}", server_name_str, err), ) })?; let tls_stream = connector .connect(server_name, stream) .await .map_err(|err| { io::Error::new( io::ErrorKind::ConnectionAborted, format!("TLS handshake failed: {}", err), ) })?; Ok(Box::new(tls_stream)) } fn build_tls_connector(tls: &TlsConfig) -> io::Result { let ca_cert_path = tls .ca_cert .as_ref() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing TLS CA cert"))?; let ca_certs = load_certs(ca_cert_path)?; let mut roots = RootCertStore::empty(); let (added, _ignored) = roots.add_parsable_certificates(ca_certs); if added == 0 { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!( "no valid CA certificates found in {}", ca_cert_path.display() ), )); } let builder = RustlsClientConfig::builder().with_root_certificates(roots); let client_config = match (&tls.client_cert, &tls.client_key) { (Some(cert_path), Some(key_path)) => { let certs = load_certs(cert_path)?; let key = load_private_key(key_path)?; builder.with_client_auth_cert(certs, key).map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid TLS client certificate/key: {}", err), ) })? } (None, None) => builder.with_no_client_auth(), _ => unreachable!(), }; Ok(TlsConnector::from(Arc::new(client_config))) } fn load_certs(path: &Path) -> io::Result>> { let mut reader = std::io::BufReader::new(std::fs::File::open(path)?); let certs = rustls_pemfile::certs(&mut reader) .collect::, _>>() .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; if certs.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("no certificates found in {}", path.display()), )); } Ok(certs) } fn load_private_key(path: &Path) -> io::Result> { let mut reader = std::io::BufReader::new(std::fs::File::open(path)?); rustls_pemfile::private_key(&mut reader) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidData, format!("no private key found in {}", path.display()), ) }) } fn default_server_name_for_addr(addr: &str) -> Option { if let Some(rest) = addr.strip_prefix('[') { return rest.split(']').next().map(str::to_string); } addr.rsplit_once(':').map(|(host, _port)| host.to_string()) } fn parse_host_port(addr: &str) -> io::Result<(String, u16)> { if let Some(rest) = addr.strip_prefix('[') { let (host, port_part) = rest.split_once("]:").ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid address '{}', expected [host]:port", addr), ) })?; let port = port_part.parse::().map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid port in address '{}': {}", addr, err), ) })?; return Ok((host.to_string(), port)); } let (host, port_part) = addr.rsplit_once(':').ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid address '{}', expected host:port", addr), ) })?; if host.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("invalid address '{}', host must not be empty", addr), )); } let port = port_part.parse::().map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid port in address '{}': {}", addr, err), ) })?; Ok((host.to_string(), port)) } fn parse_u64_arg(value: &str, name: &str) -> io::Result { let parsed = value.parse::().map_err(|err| { io::Error::new( io::ErrorKind::InvalidInput, format!("invalid value for {}: {}", name, err), ) })?; if parsed == 0 { return Err(io::Error::new( io::ErrorKind::InvalidInput, format!("{} must be > 0", name), )); } Ok(parsed) }