use base64::Engine; use quick_xml::Reader; use quick_xml::events::Event; use sha2::Digest; use std::io::{BufRead, Seek, SeekFrom, Write}; use uuid::Uuid; use crate::current_repo_index::CurrentRepoIndexHandle; use crate::storage::RocksStore; use crate::sync::store_projection::{ build_repository_view_present_entry, build_repository_view_withdrawn_entry, build_rrdp_source_member_present_record, build_rrdp_source_member_withdrawn_record, build_rrdp_uri_owner_active_record, build_rrdp_uri_owner_withdrawn_record, compute_sha256_hex, current_rrdp_owner_is, ensure_rrdp_uri_can_be_owned_by, prepare_repo_bytes_batch, }; use super::{ Fetcher, RRDP_SNAPSHOT_APPLY_BATCH_SIZE, RRDP_XMLNS, RrdpError, RrdpSyncError, parse_u64_str, strip_all_ascii_whitespace, }; #[cfg(test)] pub(super) fn apply_snapshot( store: &RocksStore, notification_uri: &str, current_repo_index: Option<&CurrentRepoIndexHandle>, snapshot_xml: &[u8], expected_session_id: Uuid, expected_serial: u64, ) -> Result { if snapshot_xml.iter().any(|&b| b > 0x7F) { return Err(RrdpError::NotAscii.into()); } apply_snapshot_from_bufread( store, notification_uri, current_repo_index, std::io::Cursor::new(snapshot_xml), expected_session_id, expected_serial, ) } pub(super) fn apply_snapshot_from_bufread( store: &RocksStore, notification_uri: &str, current_repo_index: Option<&CurrentRepoIndexHandle>, input: R, expected_session_id: Uuid, expected_serial: u64, ) -> Result { let previous_members: Vec = store .list_current_rrdp_source_members(notification_uri) .map_err(|e| RrdpSyncError::Storage(e.to_string()))? .into_iter() .map(|record| record.rsync_uri) .collect(); let session_id = expected_session_id.to_string(); let mut new_set: std::collections::HashSet = std::collections::HashSet::new(); let mut batch_published: Vec<(String, Vec)> = Vec::with_capacity(RRDP_SNAPSHOT_APPLY_BATCH_SIZE); let mut published_count = 0usize; let mut reader = Reader::from_reader(input); reader.config_mut().trim_text(false); let mut buf = Vec::new(); let mut root_seen = false; let mut in_publish = false; let mut publish_nested_depth = 0usize; let mut current_publish_uri: Option = None; let mut current_publish_text = String::new(); loop { match reader.read_event_into(&mut buf) { Ok(Event::Start(e)) => { let local_name = e.local_name(); let local_name = local_name.as_ref(); if !root_seen { root_seen = true; if local_name != b"snapshot" { let got = String::from_utf8_lossy(local_name).to_string(); return Err(RrdpError::UnexpectedRoot(got).into()); } let mut xmlns = String::new(); let mut version = String::new(); let mut session_id_attr = String::new(); let mut serial_attr = String::new(); for attr in e.attributes().with_checks(false) { let attr = match attr { Ok(attr) => attr, Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), }; let key = attr.key.as_ref(); let value = attr .decode_and_unescape_value(reader.decoder()) .map_err(|e| RrdpError::Xml(e.to_string()))? .into_owned(); match key { b"xmlns" => xmlns = value, b"version" => version = value, b"session_id" => session_id_attr = value, b"serial" => serial_attr = value, _ => {} } } if xmlns != RRDP_XMLNS { return Err(RrdpError::InvalidNamespace(xmlns).into()); } if version != "1" { return Err(RrdpError::InvalidVersion(version).into()); } let got_session_id = Uuid::parse_str(&session_id_attr) .map_err(|_| RrdpError::InvalidSessionId(session_id_attr.clone()))?; if got_session_id != expected_session_id { return Err(RrdpError::SnapshotSessionIdMismatch { expected: expected_session_id.to_string(), got: got_session_id.to_string(), } .into()); } let got_serial = parse_u64_str(&serial_attr)?; if got_serial != expected_serial { return Err(RrdpError::SnapshotSerialMismatch { expected: expected_serial, got: got_serial, } .into()); } } else if in_publish { publish_nested_depth += 1; } else if local_name == b"publish" { let mut uri = None; for attr in e.attributes().with_checks(false) { let attr = match attr { Ok(attr) => attr, Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), }; if attr.key.as_ref() == b"uri" { uri = Some( attr.decode_and_unescape_value(reader.decoder()) .map_err(|e| RrdpError::Xml(e.to_string()))? .into_owned(), ); } } let uri = uri.ok_or(RrdpError::PublishUriMissing)?; ensure_rrdp_uri_can_be_owned_by(store, notification_uri, &uri) .map_err(RrdpSyncError::Storage)?; in_publish = true; publish_nested_depth = 0; current_publish_uri = Some(uri); current_publish_text.clear(); } } Ok(Event::Empty(e)) => { let local_name = e.local_name(); let local_name = local_name.as_ref(); if !root_seen { let got = String::from_utf8_lossy(local_name).to_string(); return Err(RrdpError::UnexpectedRoot(got).into()); } if local_name == b"publish" { let mut has_uri = false; for attr in e.attributes().with_checks(false) { let attr = match attr { Ok(attr) => attr, Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), }; if attr.key.as_ref() == b"uri" { has_uri = true; break; } } if !has_uri { return Err(RrdpError::PublishUriMissing.into()); } return Err(RrdpError::PublishContentMissing.into()); } } Ok(Event::Text(e)) => { if in_publish && publish_nested_depth == 0 { let text = reader .decoder() .decode(e.as_ref()) .map_err(|e| RrdpError::Xml(e.to_string()))?; current_publish_text.push_str(&text); } } Ok(Event::CData(e)) => { if in_publish && publish_nested_depth == 0 { let text = reader .decoder() .decode(e.as_ref()) .map_err(|e| RrdpError::Xml(e.to_string()))?; current_publish_text.push_str(&text); } } Ok(Event::End(e)) => { let local_name = e.local_name(); let local_name = local_name.as_ref(); if in_publish { if publish_nested_depth > 0 { publish_nested_depth -= 1; } else if local_name == b"publish" { let uri = current_publish_uri .take() .ok_or_else(|| RrdpError::Xml("publish uri missing in state".into()))?; let content_b64 = strip_all_ascii_whitespace(¤t_publish_text); current_publish_text.clear(); if content_b64.is_empty() { return Err(RrdpError::PublishContentMissing.into()); } let bytes = base64::engine::general_purpose::STANDARD .decode(content_b64.as_bytes()) .map_err(|e| RrdpError::PublishBase64(e.to_string()))?; new_set.insert(uri.clone()); batch_published.push((uri, bytes)); published_count += 1; if batch_published.len() >= RRDP_SNAPSHOT_APPLY_BATCH_SIZE { flush_snapshot_publish_batch( store, notification_uri, current_repo_index, &session_id, expected_serial, &batch_published, )?; batch_published.clear(); } in_publish = false; } } } Ok(Event::Eof) => break, Ok(Event::Decl(_) | Event::PI(_) | Event::Comment(_) | Event::DocType(_)) => {} Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), } buf.clear(); } if !root_seen { return Err(RrdpError::Xml("missing root element".to_string()).into()); } if in_publish { return Err(RrdpError::PublishContentMissing.into()); } if !batch_published.is_empty() { flush_snapshot_publish_batch( store, notification_uri, current_repo_index, &session_id, expected_serial, &batch_published, )?; batch_published.clear(); } let mut withdrawn: Vec<(String, Option)> = Vec::new(); for old_uri in &previous_members { if new_set.contains(old_uri) { continue; } let previous_hash = store .get_repository_view_entry(old_uri) .map_err(|e| RrdpSyncError::Storage(e.to_string()))? .and_then(|entry| entry.current_hash) .or_else(|| { store .load_current_object_bytes_by_uri(old_uri) .ok() .flatten() .map(|bytes| compute_sha256_hex(&bytes)) }); withdrawn.push((old_uri.clone(), previous_hash)); } let mut repository_view_entries = Vec::with_capacity(withdrawn.len()); let mut member_records = Vec::with_capacity(withdrawn.len()); let mut owner_records = Vec::with_capacity(withdrawn.len()); for (uri, previous_hash) in withdrawn { member_records.push(build_rrdp_source_member_withdrawn_record( notification_uri, &session_id, expected_serial, &uri, previous_hash.clone(), )); if current_rrdp_owner_is(store, notification_uri, &uri).map_err(RrdpSyncError::Storage)? { repository_view_entries.push(build_repository_view_withdrawn_entry( notification_uri, &uri, previous_hash.clone(), )); owner_records.push(build_rrdp_uri_owner_withdrawn_record( notification_uri, &session_id, expected_serial, &uri, previous_hash, )); } } store .put_projection_batch(&repository_view_entries, &member_records, &owner_records) .map_err(|e| RrdpSyncError::Storage(e.to_string()))?; if let Some(index) = current_repo_index { index .lock() .map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))? .apply_repository_view_entries(&repository_view_entries) .map_err(RrdpSyncError::Storage)?; } Ok(published_count) } fn flush_snapshot_publish_batch( store: &RocksStore, notification_uri: &str, current_repo_index: Option<&CurrentRepoIndexHandle>, session_id: &str, serial: u64, published: &[(String, Vec)], ) -> Result<(), RrdpSyncError> { let prepared_bytes = prepare_repo_bytes_batch(published).map_err(RrdpSyncError::Storage)?; let mut repository_view_entries = Vec::with_capacity(published.len()); let mut member_records = Vec::with_capacity(published.len()); let mut owner_records = Vec::with_capacity(published.len()); for (uri, _bytes) in published { let current_hash = prepared_bytes .uri_to_hash .get(uri) .cloned() .ok_or_else(|| { RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}")) })?; repository_view_entries.push(build_repository_view_present_entry( notification_uri, uri, ¤t_hash, )); member_records.push(build_rrdp_source_member_present_record( notification_uri, session_id, serial, uri, ¤t_hash, )); owner_records.push(build_rrdp_uri_owner_active_record( notification_uri, session_id, serial, uri, ¤t_hash, )); } store .put_blob_bytes_batch(&prepared_bytes.blobs_to_write) .map_err(|e| RrdpSyncError::Storage(e.to_string()))?; store .put_projection_batch(&repository_view_entries, &member_records, &owner_records) .map_err(|e| RrdpSyncError::Storage(e.to_string()))?; if let Some(index) = current_repo_index { index .lock() .map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))? .apply_repository_view_entries(&repository_view_entries) .map_err(RrdpSyncError::Storage)?; } Ok(()) } const SNAPSHOT_NON_ASCII_ERROR: &str = "snapshot body contains non-ASCII bytes"; struct SnapshotSpoolWriter<'a, W: Write> { inner: &'a mut W, hasher: sha2::Sha256, bytes: u64, } impl<'a, W: Write> SnapshotSpoolWriter<'a, W> { fn new(inner: &'a mut W) -> Self { Self { inner, hasher: sha2::Sha256::new(), bytes: 0, } } fn finalize_hash(self) -> [u8; 32] { let digest = self.hasher.finalize(); let mut out = [0u8; 32]; out.copy_from_slice(&digest); out } } impl Write for SnapshotSpoolWriter<'_, W> { fn write(&mut self, buf: &[u8]) -> std::io::Result { if buf.iter().any(|&b| b > 0x7F) { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, SNAPSHOT_NON_ASCII_ERROR, )); } let n = self.inner.write(buf)?; self.hasher.update(&buf[..n]); self.bytes += n as u64; Ok(n) } fn flush(&mut self) -> std::io::Result<()> { self.inner.flush() } } pub(super) fn fetch_snapshot_into_tempfile( fetcher: &dyn Fetcher, snapshot_uri: &str, expected_hash_sha256: &[u8; 32], ) -> Result<(tempfile::NamedTempFile, u64), RrdpSyncError> { let mut tmp = tempfile::NamedTempFile::new() .map_err(|e| RrdpSyncError::Fetch(format!("tempfile create failed: {e}")))?; let mut spool = SnapshotSpoolWriter::new(tmp.as_file_mut()); let bytes_written = match fetcher.fetch_to_writer(snapshot_uri, &mut spool) { Ok(bytes) => bytes, Err(e) if e.contains(SNAPSHOT_NON_ASCII_ERROR) => return Err(RrdpError::NotAscii.into()), Err(e) => return Err(RrdpSyncError::Fetch(e)), }; let computed = spool.finalize_hash(); if computed.as_slice() != expected_hash_sha256.as_slice() { return Err(RrdpError::SnapshotHashMismatch.into()); } tmp.as_file_mut() .flush() .map_err(|e| RrdpSyncError::Fetch(format!("tempfile flush failed: {e}")))?; tmp.as_file_mut() .seek(SeekFrom::Start(0)) .map_err(|e| RrdpSyncError::Fetch(format!("tempfile rewind failed: {e}")))?; Ok((tmp, bytes_written)) }