From e45830d79f058eef4c687b35dc0f5550ccca3c8b Mon Sep 17 00:00:00 2001 From: yuyr Date: Sat, 11 Apr 2026 14:45:08 +0800 Subject: [PATCH] =?UTF-8?q?20260411=20apply=20snapahot=E5=86=85=E5=AD=98?= =?UTF-8?q?=E4=BC=98=E5=8C=96=EF=BC=8C=E9=87=87=E7=94=A8=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E5=86=99=E6=96=87=E4=BB=B6=E5=92=8C=E5=88=86=E5=9D=97=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=99=8D=E4=BD=8E=E8=BF=90=E8=A1=8C=E5=86=85=E5=AD=98?= =?UTF-8?q?=E9=9C=80=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.toml | 3 +- src/fetch/http.rs | 111 ++++++++++ src/sync/rrdp.rs | 526 ++++++++++++++++++++++++++++++++++++---------- 3 files changed, 523 insertions(+), 117 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fbaa061..620bd39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,10 +26,11 @@ toml = "0.8.20" rocksdb = { version = "0.22.0", optional = true, default-features = false, features = ["lz4"] } serde_cbor = "0.11.2" roxmltree = "0.20.0" +quick-xml = "0.37.2" uuid = { version = "1.7.0", features = ["v4"] } reqwest = { version = "0.12.12", default-features = false, features = ["blocking", "rustls-tls"] } pprof = { version = "0.14.1", optional = true, features = ["flamegraph", "prost-codec"] } flate2 = { version = "1.0.35", optional = true } +tempfile = "3.16.0" [dev-dependencies] -tempfile = "3.16.0" diff --git a/src/fetch/http.rs b/src/fetch/http.rs index bc7290d..0a84034 100644 --- a/src/fetch/http.rs +++ b/src/fetch/http.rs @@ -1,3 +1,4 @@ +use std::io::Write; use std::time::Duration; use reqwest::blocking::Client; @@ -191,6 +192,116 @@ impl Fetcher for BlockingHttpFetcher { fn fetch(&self, uri: &str) -> Result, String> { self.fetch_bytes(uri) } + + fn fetch_to_writer(&self, uri: &str, out: &mut dyn Write) -> Result { + let started = std::time::Instant::now(); + let (client, timeout_profile, timeout_value) = self.client_for_uri(uri); + let resp = client.get(uri).send().map_err(|e| { + let msg = format!("http request failed: {e}"); + crate::progress_log::emit( + "http_fetch_failed", + serde_json::json!({ + "uri": uri, + "stage": "request", + "timeout_profile": timeout_profile, + "request_timeout_ms": timeout_value.as_millis() as u64, + "duration_ms": started.elapsed().as_millis() as u64, + "error": msg, + }), + ); + msg + })?; + + let status = resp.status(); + let headers = resp.headers().clone(); + if !status.is_success() { + let body_preview = resp + .text() + .ok() + .map(|text| text.chars().take(160).collect::()); + let body_prefix = body_preview + .clone() + .unwrap_or_else(|| "".to_string()); + let msg = format!( + "http status {status}; content_type={}; content_encoding={}; content_length={}; transfer_encoding={}; body_prefix={}", + header_value(&headers, "content-type"), + header_value(&headers, "content-encoding"), + header_value(&headers, "content-length"), + header_value(&headers, "transfer-encoding"), + body_prefix, + ); + crate::progress_log::emit( + "http_fetch_failed", + serde_json::json!({ + "uri": uri, + "stage": "status", + "timeout_profile": timeout_profile, + "request_timeout_ms": timeout_value.as_millis() as u64, + "duration_ms": started.elapsed().as_millis() as u64, + "status": status.as_u16(), + "content_type": header_value_opt(&headers, "content-type"), + "content_encoding": header_value_opt(&headers, "content-encoding"), + "content_length": header_value_opt(&headers, "content-length"), + "transfer_encoding": header_value_opt(&headers, "transfer-encoding"), + "body_prefix": body_preview, + "error": msg, + }), + ); + return Err(msg); + } + + let mut resp = resp; + match resp.copy_to(out) { + Ok(bytes) => { + let duration_ms = started.elapsed().as_millis() as u64; + if (duration_ms as f64) / 1000.0 >= crate::progress_log::slow_threshold_secs() { + crate::progress_log::emit( + "http_fetch_slow", + serde_json::json!({ + "uri": uri, + "status": status.as_u16(), + "timeout_profile": timeout_profile, + "request_timeout_ms": timeout_value.as_millis() as u64, + "duration_ms": duration_ms, + "bytes": bytes, + "content_type": header_value_opt(&headers, "content-type"), + "content_encoding": header_value_opt(&headers, "content-encoding"), + "content_length": header_value_opt(&headers, "content-length"), + "transfer_encoding": header_value_opt(&headers, "transfer-encoding"), + }), + ); + } + Ok(bytes) + } + Err(e) => { + let msg = format!( + "http stream body failed: {e}; status={}; content_type={}; content_encoding={}; content_length={}; transfer_encoding={}", + status, + header_value(&headers, "content-type"), + header_value(&headers, "content-encoding"), + header_value(&headers, "content-length"), + header_value(&headers, "transfer-encoding"), + ); + crate::progress_log::emit( + "http_fetch_failed", + serde_json::json!({ + "uri": uri, + "stage": "stream_body", + "timeout_profile": timeout_profile, + "request_timeout_ms": timeout_value.as_millis() as u64, + "duration_ms": started.elapsed().as_millis() as u64, + "status": status.as_u16(), + "content_type": header_value_opt(&headers, "content-type"), + "content_encoding": header_value_opt(&headers, "content-encoding"), + "content_length": header_value_opt(&headers, "content-length"), + "transfer_encoding": header_value_opt(&headers, "transfer-encoding"), + "error": msg, + }), + ); + Err(msg) + } + } + } } fn header_value(headers: &HeaderMap, name: &str) -> String { diff --git a/src/sync/rrdp.rs b/src/sync/rrdp.rs index ccc8459..fc23cc0 100644 --- a/src/sync/rrdp.rs +++ b/src/sync/rrdp.rs @@ -12,11 +12,15 @@ use crate::sync::store_projection::{ update_rrdp_source_record_on_success, upsert_raw_by_hash_evidence, }; use base64::Engine; +use quick_xml::events::Event; +use quick_xml::Reader; use serde::{Deserialize, Serialize}; use sha2::Digest; +use std::io::{BufRead, Seek, SeekFrom, Write}; use uuid::Uuid; const RRDP_XMLNS: &str = "http://www.ripe.net/rpki/rrdp"; +const RRDP_SNAPSHOT_APPLY_BATCH_SIZE: usize = 1024; #[derive(Debug, thiserror::Error)] pub enum RrdpError { @@ -190,6 +194,13 @@ pub type RrdpSyncResult = Result; pub trait Fetcher { fn fetch(&self, uri: &str) -> Result, String>; + + fn fetch_to_writer(&self, uri: &str, out: &mut dyn Write) -> Result { + let bytes = self.fetch(uri)?; + out.write_all(&bytes) + .map_err(|e| format!("write sink failed: {e}"))?; + Ok(bytes.len() as u64) + } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -548,19 +559,20 @@ fn sync_from_notification_snapshot_inner( .map(|t| t.span_phase("rrdp_fetch_snapshot_total")); let mut dl_span = download_log .map(|dl| dl.span_download(AuditDownloadKind::RrdpSnapshot, ¬if.snapshot_uri)); - let snapshot_xml = match fetcher.fetch(¬if.snapshot_uri) { + let (snapshot_file, _snapshot_bytes) = + match fetch_snapshot_into_tempfile(fetcher, ¬if.snapshot_uri, ¬if.snapshot_hash_sha256) { Ok(v) => { if let Some(t) = timing.as_ref() { t.record_count("rrdp_snapshot_fetch_ok_total", 1); - t.record_count("rrdp_snapshot_bytes_total", v.len() as u64); + t.record_count("rrdp_snapshot_bytes_total", v.1); } if let Some(s) = dl_span.as_mut() { - s.set_bytes(v.len() as u64); + s.set_bytes(v.1); s.set_ok(); } v } - Err(e) => { + Err(RrdpSyncError::Fetch(e)) => { if let Some(t) = timing.as_ref() { t.record_count("rrdp_snapshot_fetch_fail_total", 1); } @@ -569,33 +581,25 @@ fn sync_from_notification_snapshot_inner( } return Err(RrdpSyncError::Fetch(e)); } + Err(e) => return Err(e), }; drop(_fetch_step); drop(_fetch_total); - let _hash_step = timing - .as_ref() - .map(|t| t.span_rrdp_repo_step(notification_uri, "hash_snapshot")); - let _hash_total = timing - .as_ref() - .map(|t| t.span_phase("rrdp_hash_snapshot_total")); - let computed = sha2::Sha256::digest(&snapshot_xml); - if computed.as_slice() != notif.snapshot_hash_sha256.as_slice() { - return Err(RrdpError::SnapshotHashMismatch.into()); - } - drop(_hash_step); - drop(_hash_total); - let _apply_step = timing .as_ref() .map(|t| t.span_rrdp_repo_step(notification_uri, "apply_snapshot")); let _apply_total = timing .as_ref() .map(|t| t.span_phase("rrdp_apply_snapshot_total")); - let published = apply_snapshot( + let published = apply_snapshot_from_bufread( store, notification_uri, - &snapshot_xml, + std::io::BufReader::new( + snapshot_file + .reopen() + .map_err(|e| RrdpSyncError::Fetch(format!("tempfile reopen failed: {e}")))?, + ), notif.session_id, notif.serial, )?; @@ -866,19 +870,20 @@ fn sync_from_notification_inner( .map(|t| t.span_phase("rrdp_fetch_snapshot_total")); let mut dl_span = download_log .map(|dl| dl.span_download(AuditDownloadKind::RrdpSnapshot, ¬if.snapshot_uri)); - let snapshot_xml = match fetcher.fetch(¬if.snapshot_uri) { + let (snapshot_file, _snapshot_bytes) = + match fetch_snapshot_into_tempfile(fetcher, ¬if.snapshot_uri, ¬if.snapshot_hash_sha256) { Ok(v) => { if let Some(t) = timing.as_ref() { t.record_count("rrdp_snapshot_fetch_ok_total", 1); - t.record_count("rrdp_snapshot_bytes_total", v.len() as u64); + t.record_count("rrdp_snapshot_bytes_total", v.1); } if let Some(s) = dl_span.as_mut() { - s.set_bytes(v.len() as u64); + s.set_bytes(v.1); s.set_ok(); } v } - Err(e) => { + Err(RrdpSyncError::Fetch(e)) => { if let Some(t) = timing.as_ref() { t.record_count("rrdp_snapshot_fetch_fail_total", 1); } @@ -887,33 +892,25 @@ fn sync_from_notification_inner( } return Err(RrdpSyncError::Fetch(e)); } + Err(e) => return Err(e), }; drop(_fetch_step); drop(_fetch_total); - let _hash_step = timing - .as_ref() - .map(|t| t.span_rrdp_repo_step(notification_uri, "hash_snapshot")); - let _hash_total = timing - .as_ref() - .map(|t| t.span_phase("rrdp_hash_snapshot_total")); - let computed = sha2::Sha256::digest(&snapshot_xml); - if computed.as_slice() != notif.snapshot_hash_sha256.as_slice() { - return Err(RrdpError::SnapshotHashMismatch.into()); - } - drop(_hash_step); - drop(_hash_total); - let _apply_step = timing .as_ref() .map(|t| t.span_rrdp_repo_step(notification_uri, "apply_snapshot")); let _apply_total = timing .as_ref() .map(|t| t.span_phase("rrdp_apply_snapshot_total")); - let published = apply_snapshot( + let published = apply_snapshot_from_bufread( store, notification_uri, - &snapshot_xml, + std::io::BufReader::new( + snapshot_file + .reopen() + .map_err(|e| RrdpSyncError::Fetch(format!("tempfile reopen failed: {e}")))?, + ), notif.session_id, notif.serial, )?; @@ -1154,61 +1151,236 @@ fn apply_snapshot( expected_session_id: Uuid, expected_serial: u64, ) -> Result { - let doc = parse_rrdp_xml(snapshot_xml)?; - let root = doc.root_element(); - if root.tag_name().name() != "snapshot" { - return Err(RrdpError::UnexpectedRoot(root.tag_name().name().to_string()).into()); - } - validate_root_common(&root)?; - - let got_session_id = parse_uuid_attr(&root, "session_id")?; - if got_session_id != expected_session_id { - return Err(RrdpError::SnapshotSessionIdMismatch { - expected: expected_session_id.to_string(), - got: got_session_id.to_string(), - } - .into()); - } - let got_serial = parse_u64_attr(&root, "serial")?; - if got_serial != expected_serial { - return Err(RrdpError::SnapshotSerialMismatch { - expected: expected_serial, - got: got_serial, - } - .into()); - } - - let mut published: Vec<(String, Vec)> = Vec::new(); - for publish in root - .children() - .filter(|n| n.is_element() && n.tag_name().name() == "publish") - { - let uri = publish - .attribute("uri") - .ok_or(RrdpError::PublishUriMissing)?; - let content_b64 = collect_element_text(&publish).ok_or(RrdpError::PublishContentMissing)?; - let content_b64 = strip_all_ascii_whitespace(&content_b64); - - let bytes = base64::engine::general_purpose::STANDARD - .decode(content_b64.as_bytes()) - .map_err(|e| RrdpError::PublishBase64(e.to_string()))?; - - ensure_rrdp_uri_can_be_owned_by(store, notification_uri, uri) - .map_err(RrdpSyncError::Storage)?; - published.push((uri.to_string(), bytes)); + if snapshot_xml.iter().any(|&b| b > 0x7F) { + return Err(RrdpError::NotAscii.into()); } + apply_snapshot_from_bufread( + store, + notification_uri, + std::io::Cursor::new(snapshot_xml), + expected_session_id, + expected_serial, + ) +} +fn apply_snapshot_from_bufread( + store: &RocksStore, + notification_uri: &str, + input: R, + expected_session_id: Uuid, + expected_serial: u64, +) -> Result { let previous_members: Vec = store .list_current_rrdp_source_members(notification_uri) .map_err(|e| RrdpSyncError::Storage(e.to_string()))? .into_iter() .map(|record| record.rsync_uri) .collect(); - let new_set: std::collections::HashSet<&str> = - published.iter().map(|(uri, _)| uri.as_str()).collect(); + let session_id = expected_session_id.to_string(); + let mut new_set: std::collections::HashSet = std::collections::HashSet::new(); + let mut batch_published: Vec<(String, Vec)> = + Vec::with_capacity(RRDP_SNAPSHOT_APPLY_BATCH_SIZE); + let mut published_count = 0usize; + + let mut reader = Reader::from_reader(input); + reader.config_mut().trim_text(false); + let mut buf = Vec::new(); + let mut root_seen = false; + let mut in_publish = false; + let mut publish_nested_depth = 0usize; + let mut current_publish_uri: Option = None; + let mut current_publish_text = String::new(); + + loop { + match reader.read_event_into(&mut buf) { + Ok(Event::Start(e)) => { + let local_name = e.local_name(); + let local_name = local_name.as_ref(); + if !root_seen { + root_seen = true; + if local_name != b"snapshot" { + let got = String::from_utf8_lossy(local_name).to_string(); + return Err(RrdpError::UnexpectedRoot(got).into()); + } + let mut xmlns = String::new(); + let mut version = String::new(); + let mut session_id_attr = String::new(); + let mut serial_attr = String::new(); + for attr in e.attributes().with_checks(false) { + let attr = match attr { + Ok(attr) => attr, + Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), + }; + let key = attr.key.as_ref(); + let value = attr + .decode_and_unescape_value(reader.decoder()) + .map_err(|e| RrdpError::Xml(e.to_string()))? + .into_owned(); + match key { + b"xmlns" => xmlns = value, + b"version" => version = value, + b"session_id" => session_id_attr = value, + b"serial" => serial_attr = value, + _ => {} + } + } + if xmlns != RRDP_XMLNS { + return Err(RrdpError::InvalidNamespace(xmlns).into()); + } + if version != "1" { + return Err(RrdpError::InvalidVersion(version).into()); + } + let got_session_id = Uuid::parse_str(&session_id_attr) + .map_err(|_| RrdpError::InvalidSessionId(session_id_attr.clone()))?; + if got_session_id != expected_session_id { + return Err(RrdpError::SnapshotSessionIdMismatch { + expected: expected_session_id.to_string(), + got: got_session_id.to_string(), + } + .into()); + } + let got_serial = parse_u64_str(&serial_attr)?; + if got_serial != expected_serial { + return Err(RrdpError::SnapshotSerialMismatch { + expected: expected_serial, + got: got_serial, + } + .into()); + } + } else if in_publish { + publish_nested_depth += 1; + } else if local_name == b"publish" { + let mut uri = None; + for attr in e.attributes().with_checks(false) { + let attr = match attr { + Ok(attr) => attr, + Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), + }; + if attr.key.as_ref() == b"uri" { + uri = Some( + attr.decode_and_unescape_value(reader.decoder()) + .map_err(|e| RrdpError::Xml(e.to_string()))? + .into_owned(), + ); + } + } + let uri = uri.ok_or(RrdpError::PublishUriMissing)?; + ensure_rrdp_uri_can_be_owned_by(store, notification_uri, &uri) + .map_err(RrdpSyncError::Storage)?; + in_publish = true; + publish_nested_depth = 0; + current_publish_uri = Some(uri); + current_publish_text.clear(); + } + } + Ok(Event::Empty(e)) => { + let local_name = e.local_name(); + let local_name = local_name.as_ref(); + if !root_seen { + let got = String::from_utf8_lossy(local_name).to_string(); + return Err(RrdpError::UnexpectedRoot(got).into()); + } + if local_name == b"publish" { + let mut has_uri = false; + for attr in e.attributes().with_checks(false) { + let attr = match attr { + Ok(attr) => attr, + Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), + }; + if attr.key.as_ref() == b"uri" { + has_uri = true; + break; + } + } + if !has_uri { + return Err(RrdpError::PublishUriMissing.into()); + } + return Err(RrdpError::PublishContentMissing.into()); + } + } + Ok(Event::Text(e)) => { + if in_publish && publish_nested_depth == 0 { + let text = reader + .decoder() + .decode(e.as_ref()) + .map_err(|e| RrdpError::Xml(e.to_string()))?; + current_publish_text.push_str(&text); + } + } + Ok(Event::CData(e)) => { + if in_publish && publish_nested_depth == 0 { + let text = reader + .decoder() + .decode(e.as_ref()) + .map_err(|e| RrdpError::Xml(e.to_string()))?; + current_publish_text.push_str(&text); + } + } + Ok(Event::End(e)) => { + let local_name = e.local_name(); + let local_name = local_name.as_ref(); + if in_publish { + if publish_nested_depth > 0 { + publish_nested_depth -= 1; + } else if local_name == b"publish" { + let uri = current_publish_uri + .take() + .ok_or_else(|| RrdpError::Xml("publish uri missing in state".into()))?; + let content_b64 = strip_all_ascii_whitespace(¤t_publish_text); + current_publish_text.clear(); + if content_b64.is_empty() { + return Err(RrdpError::PublishContentMissing.into()); + } + let bytes = base64::engine::general_purpose::STANDARD + .decode(content_b64.as_bytes()) + .map_err(|e| RrdpError::PublishBase64(e.to_string()))?; + new_set.insert(uri.clone()); + batch_published.push((uri, bytes)); + published_count += 1; + if batch_published.len() >= RRDP_SNAPSHOT_APPLY_BATCH_SIZE { + flush_snapshot_publish_batch( + store, + notification_uri, + &session_id, + expected_serial, + &batch_published, + )?; + batch_published.clear(); + } + in_publish = false; + } + } + } + Ok(Event::Eof) => break, + Ok(Event::Decl(_) + | Event::PI(_) + | Event::Comment(_) + | Event::DocType(_)) => {} + Err(e) => return Err(RrdpError::Xml(e.to_string()).into()), + } + buf.clear(); + } + + if !root_seen { + return Err(RrdpError::Xml("missing root element".to_string()).into()); + } + if in_publish { + return Err(RrdpError::PublishContentMissing.into()); + } + if !batch_published.is_empty() { + flush_snapshot_publish_batch( + store, + notification_uri, + &session_id, + expected_serial, + &batch_published, + )?; + batch_published.clear(); + } + let mut withdrawn: Vec<(String, Option)> = Vec::new(); for old_uri in &previous_members { - if new_set.contains(old_uri.as_str()) { + if new_set.contains(old_uri) { continue; } let previous_hash = store @@ -1225,38 +1397,9 @@ fn apply_snapshot( withdrawn.push((old_uri.clone(), previous_hash)); } - let session_id = expected_session_id.to_string(); - let prepared_raw = - prepare_raw_by_hash_evidence_batch(store, &published).map_err(RrdpSyncError::Storage)?; - let mut repository_view_entries = Vec::with_capacity(published.len() + withdrawn.len()); - let mut member_records = Vec::with_capacity(published.len() + withdrawn.len()); - let mut owner_records = Vec::with_capacity(published.len() + withdrawn.len()); - - for (uri, _bytes) in &published { - let current_hash = prepared_raw.uri_to_hash.get(uri).cloned().ok_or_else(|| { - RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}")) - })?; - repository_view_entries.push(build_repository_view_present_entry( - notification_uri, - uri, - ¤t_hash, - )); - member_records.push(build_rrdp_source_member_present_record( - notification_uri, - &session_id, - expected_serial, - uri, - ¤t_hash, - )); - owner_records.push(build_rrdp_uri_owner_active_record( - notification_uri, - &session_id, - expected_serial, - uri, - ¤t_hash, - )); - } - + let mut repository_view_entries = Vec::with_capacity(withdrawn.len()); + let mut member_records = Vec::with_capacity(withdrawn.len()); + let mut owner_records = Vec::with_capacity(withdrawn.len()); for (uri, previous_hash) in withdrawn { member_records.push(build_rrdp_source_member_withdrawn_record( notification_uri, @@ -1280,6 +1423,50 @@ fn apply_snapshot( )); } } + store + .put_projection_batch(&repository_view_entries, &member_records, &owner_records) + .map_err(|e| RrdpSyncError::Storage(e.to_string()))?; + + Ok(published_count) +} + +fn flush_snapshot_publish_batch( + store: &RocksStore, + notification_uri: &str, + session_id: &str, + serial: u64, + published: &[(String, Vec)], +) -> Result<(), RrdpSyncError> { + let prepared_raw = + prepare_raw_by_hash_evidence_batch(store, published).map_err(RrdpSyncError::Storage)?; + let mut repository_view_entries = Vec::with_capacity(published.len()); + let mut member_records = Vec::with_capacity(published.len()); + let mut owner_records = Vec::with_capacity(published.len()); + + for (uri, _bytes) in published { + let current_hash = prepared_raw.uri_to_hash.get(uri).cloned().ok_or_else(|| { + RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}")) + })?; + repository_view_entries.push(build_repository_view_present_entry( + notification_uri, + uri, + ¤t_hash, + )); + member_records.push(build_rrdp_source_member_present_record( + notification_uri, + session_id, + serial, + uri, + ¤t_hash, + )); + owner_records.push(build_rrdp_uri_owner_active_record( + notification_uri, + session_id, + serial, + uri, + ¤t_hash, + )); + } store .put_raw_by_hash_entries_batch_unchecked(&prepared_raw.entries_to_write) @@ -1288,7 +1475,81 @@ fn apply_snapshot( .put_projection_batch(&repository_view_entries, &member_records, &owner_records) .map_err(|e| RrdpSyncError::Storage(e.to_string()))?; - Ok(published.len()) + Ok(()) +} + +const SNAPSHOT_NON_ASCII_ERROR: &str = "snapshot body contains non-ASCII bytes"; + +struct SnapshotSpoolWriter<'a, W: Write> { + inner: &'a mut W, + hasher: sha2::Sha256, + bytes: u64, +} + +impl<'a, W: Write> SnapshotSpoolWriter<'a, W> { + fn new(inner: &'a mut W) -> Self { + Self { + inner, + hasher: sha2::Sha256::new(), + bytes: 0, + } + } + + fn bytes_written(&self) -> u64 { + self.bytes + } + + fn finalize_hash(self) -> [u8; 32] { + let digest = self.hasher.finalize(); + let mut out = [0u8; 32]; + out.copy_from_slice(&digest); + out + } +} + +impl Write for SnapshotSpoolWriter<'_, W> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + if buf.iter().any(|&b| b > 0x7F) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + SNAPSHOT_NON_ASCII_ERROR, + )); + } + let n = self.inner.write(buf)?; + self.hasher.update(&buf[..n]); + self.bytes += n as u64; + Ok(n) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.inner.flush() + } +} + +fn fetch_snapshot_into_tempfile( + fetcher: &dyn Fetcher, + snapshot_uri: &str, + expected_hash_sha256: &[u8; 32], +) -> Result<(tempfile::NamedTempFile, u64), RrdpSyncError> { + let mut tmp = + tempfile::NamedTempFile::new().map_err(|e| RrdpSyncError::Fetch(format!("tempfile create failed: {e}")))?; + let mut spool = SnapshotSpoolWriter::new(tmp.as_file_mut()); + let bytes_written = match fetcher.fetch_to_writer(snapshot_uri, &mut spool) { + Ok(bytes) => bytes, + Err(e) if e.contains(SNAPSHOT_NON_ASCII_ERROR) => return Err(RrdpError::NotAscii.into()), + Err(e) => return Err(RrdpSyncError::Fetch(e)), + }; + let computed = spool.finalize_hash(); + if computed.as_slice() != expected_hash_sha256.as_slice() { + return Err(RrdpError::SnapshotHashMismatch.into()); + } + tmp.as_file_mut() + .flush() + .map_err(|e| RrdpSyncError::Fetch(format!("tempfile flush failed: {e}")))?; + tmp.as_file_mut() + .seek(SeekFrom::Start(0)) + .map_err(|e| RrdpSyncError::Fetch(format!("tempfile rewind failed: {e}")))?; + Ok((tmp, bytes_written)) } fn parse_rrdp_xml(xml: &[u8]) -> Result, RrdpError> { @@ -2471,4 +2732,37 @@ mod tests { RrdpSyncError::Rrdp(RrdpError::PublishBase64(_)) )); } + + #[test] + fn apply_snapshot_handles_multiple_publish_batches() { + let tmp = tempfile::tempdir().expect("tempdir"); + let store = RocksStore::open(tmp.path()).expect("open rocksdb"); + let notif_uri = "https://example.net/notification.xml"; + let sid = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(); + + let total = RRDP_SNAPSHOT_APPLY_BATCH_SIZE + 7; + let mut xml = format!( + r#""# + ); + for i in 0..total { + let uri = format!("rsync://example.net/repo/{i:04}.roa"); + let bytes = format!("payload-{i}").into_bytes(); + let b64 = base64::engine::general_purpose::STANDARD.encode(bytes); + xml.push_str(&format!(r#"{b64}"#)); + } + xml.push_str(""); + + let published = + apply_snapshot(&store, notif_uri, xml.as_bytes(), sid, 1).expect("apply snapshot"); + assert_eq!(published, total); + + for idx in [0usize, RRDP_SNAPSHOT_APPLY_BATCH_SIZE - 1, total - 1] { + let uri = format!("rsync://example.net/repo/{idx:04}.roa"); + let got = store + .load_current_object_bytes_by_uri(&uri) + .expect("load object") + .expect("object exists"); + assert_eq!(got, format!("payload-{idx}").into_bytes()); + } + } }