20260411 apply snapshot内存优化,采用流式写文件和分块处理降低运行内存需求

This commit is contained in:
yuyr 2026-04-11 14:45:08 +08:00
parent 77fc2f1a41
commit e45830d79f
3 changed files with 523 additions and 117 deletions

View File

@ -26,10 +26,11 @@ toml = "0.8.20"
rocksdb = { version = "0.22.0", optional = true, default-features = false, features = ["lz4"] }
serde_cbor = "0.11.2"
roxmltree = "0.20.0"
quick-xml = "0.37.2"
uuid = { version = "1.7.0", features = ["v4"] }
reqwest = { version = "0.12.12", default-features = false, features = ["blocking", "rustls-tls"] }
pprof = { version = "0.14.1", optional = true, features = ["flamegraph", "prost-codec"] }
flate2 = { version = "1.0.35", optional = true }
tempfile = "3.16.0"
[dev-dependencies]
tempfile = "3.16.0"

View File

@ -1,3 +1,4 @@
use std::io::Write;
use std::time::Duration;
use reqwest::blocking::Client;
@ -191,6 +192,116 @@ impl Fetcher for BlockingHttpFetcher {
/// Trait entry point: delegates to the buffered inherent fetch
/// (`fetch_bytes`, defined elsewhere in this impl — TODO confirm it
/// returns the full response body).
fn fetch(&self, uri: &str) -> Result<Vec<u8>, String> {
self.fetch_bytes(uri)
}
/// Streams the HTTP response body for `uri` directly into `out`,
/// returning the number of body bytes copied.
///
/// Emits structured progress-log events on request failure, non-success
/// status, body-streaming failure, and on slow-but-successful fetches.
fn fetch_to_writer(&self, uri: &str, out: &mut dyn Write) -> Result<u64, String> {
let started = std::time::Instant::now();
// Per-URI client selection; the profile name and timeout are echoed into
// every log event for diagnosis.
let (client, timeout_profile, timeout_value) = self.client_for_uri(uri);
let resp = client.get(uri).send().map_err(|e| {
let msg = format!("http request failed: {e}");
crate::progress_log::emit(
"http_fetch_failed",
serde_json::json!({
"uri": uri,
"stage": "request",
"timeout_profile": timeout_profile,
"request_timeout_ms": timeout_value.as_millis() as u64,
"duration_ms": started.elapsed().as_millis() as u64,
"error": msg,
}),
);
msg
})?;
let status = resp.status();
// Clone headers up front: the response is consumed below by text()/copy_to().
let headers = resp.headers().clone();
if !status.is_success() {
// Keep at most 160 chars of the error body for diagnostics; reading the
// body can itself fail, hence the Option handling.
let body_preview = resp
.text()
.ok()
.map(|text| text.chars().take(160).collect::<String>());
let body_prefix = body_preview
.clone()
.unwrap_or_else(|| "<unavailable>".to_string());
let msg = format!(
"http status {status}; content_type={}; content_encoding={}; content_length={}; transfer_encoding={}; body_prefix={}",
header_value(&headers, "content-type"),
header_value(&headers, "content-encoding"),
header_value(&headers, "content-length"),
header_value(&headers, "transfer-encoding"),
body_prefix,
);
crate::progress_log::emit(
"http_fetch_failed",
serde_json::json!({
"uri": uri,
"stage": "status",
"timeout_profile": timeout_profile,
"request_timeout_ms": timeout_value.as_millis() as u64,
"duration_ms": started.elapsed().as_millis() as u64,
"status": status.as_u16(),
"content_type": header_value_opt(&headers, "content-type"),
"content_encoding": header_value_opt(&headers, "content-encoding"),
"content_length": header_value_opt(&headers, "content-length"),
"transfer_encoding": header_value_opt(&headers, "transfer-encoding"),
"body_prefix": body_preview,
"error": msg,
}),
);
return Err(msg);
}
// copy_to needs a mutable response; rebind instead of marking the original
// binding mutable.
let mut resp = resp;
match resp.copy_to(out) {
Ok(bytes) => {
let duration_ms = started.elapsed().as_millis() as u64;
// Successful fetches are logged only when they exceed the slow threshold.
if (duration_ms as f64) / 1000.0 >= crate::progress_log::slow_threshold_secs() {
crate::progress_log::emit(
"http_fetch_slow",
serde_json::json!({
"uri": uri,
"status": status.as_u16(),
"timeout_profile": timeout_profile,
"request_timeout_ms": timeout_value.as_millis() as u64,
"duration_ms": duration_ms,
"bytes": bytes,
"content_type": header_value_opt(&headers, "content-type"),
"content_encoding": header_value_opt(&headers, "content-encoding"),
"content_length": header_value_opt(&headers, "content-length"),
"transfer_encoding": header_value_opt(&headers, "transfer-encoding"),
}),
);
}
Ok(bytes)
}
Err(e) => {
// Streaming failed mid-body; include header context in the error since
// the caller only sees a String.
let msg = format!(
"http stream body failed: {e}; status={}; content_type={}; content_encoding={}; content_length={}; transfer_encoding={}",
status,
header_value(&headers, "content-type"),
header_value(&headers, "content-encoding"),
header_value(&headers, "content-length"),
header_value(&headers, "transfer-encoding"),
);
crate::progress_log::emit(
"http_fetch_failed",
serde_json::json!({
"uri": uri,
"stage": "stream_body",
"timeout_profile": timeout_profile,
"request_timeout_ms": timeout_value.as_millis() as u64,
"duration_ms": started.elapsed().as_millis() as u64,
"status": status.as_u16(),
"content_type": header_value_opt(&headers, "content-type"),
"content_encoding": header_value_opt(&headers, "content-encoding"),
"content_length": header_value_opt(&headers, "content-length"),
"transfer_encoding": header_value_opt(&headers, "transfer-encoding"),
"error": msg,
}),
);
Err(msg)
}
}
}
}
fn header_value(headers: &HeaderMap, name: &str) -> String {

View File

@ -12,11 +12,15 @@ use crate::sync::store_projection::{
update_rrdp_source_record_on_success, upsert_raw_by_hash_evidence,
};
use base64::Engine;
use quick_xml::events::Event;
use quick_xml::Reader;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use std::io::{BufRead, Seek, SeekFrom, Write};
use uuid::Uuid;
// XML namespace expected on RRDP snapshot/notification documents.
const RRDP_XMLNS: &str = "http://www.ripe.net/rpki/rrdp";
// Number of <publish> elements applied per storage batch while streaming a
// snapshot; bounds peak memory during apply.
const RRDP_SNAPSHOT_APPLY_BATCH_SIZE: usize = 1024;
#[derive(Debug, thiserror::Error)]
pub enum RrdpError {
@ -190,6 +194,13 @@ pub type RrdpSyncResult<T> = Result<T, RrdpSyncError>;
/// Abstraction over byte retrieval so RRDP sync can be driven by real HTTP
/// in production and by in-memory fakes in tests.
pub trait Fetcher {
    /// Fetch the resource at `uri` and return its complete body.
    fn fetch(&self, uri: &str) -> Result<Vec<u8>, String>;
    /// Stream the resource at `uri` into `out`, returning the byte count.
    ///
    /// The default implementation buffers the entire body via `fetch` and
    /// writes it out in a single call; streaming implementations should
    /// override this to bound memory usage.
    fn fetch_to_writer(&self, uri: &str, out: &mut dyn Write) -> Result<u64, String> {
        let body = self.fetch(uri)?;
        match out.write_all(&body) {
            Ok(()) => Ok(body.len() as u64),
            Err(e) => Err(format!("write sink failed: {e}")),
        }
    }
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
@ -548,19 +559,20 @@ fn sync_from_notification_snapshot_inner(
.map(|t| t.span_phase("rrdp_fetch_snapshot_total"));
let mut dl_span = download_log
.map(|dl| dl.span_download(AuditDownloadKind::RrdpSnapshot, &notif.snapshot_uri));
let snapshot_xml = match fetcher.fetch(&notif.snapshot_uri) {
let (snapshot_file, _snapshot_bytes) =
match fetch_snapshot_into_tempfile(fetcher, &notif.snapshot_uri, &notif.snapshot_hash_sha256) {
Ok(v) => {
if let Some(t) = timing.as_ref() {
t.record_count("rrdp_snapshot_fetch_ok_total", 1);
t.record_count("rrdp_snapshot_bytes_total", v.len() as u64);
t.record_count("rrdp_snapshot_bytes_total", v.1);
}
if let Some(s) = dl_span.as_mut() {
s.set_bytes(v.len() as u64);
s.set_bytes(v.1);
s.set_ok();
}
v
}
Err(e) => {
Err(RrdpSyncError::Fetch(e)) => {
if let Some(t) = timing.as_ref() {
t.record_count("rrdp_snapshot_fetch_fail_total", 1);
}
@ -569,33 +581,25 @@ fn sync_from_notification_snapshot_inner(
}
return Err(RrdpSyncError::Fetch(e));
}
Err(e) => return Err(e),
};
drop(_fetch_step);
drop(_fetch_total);
let _hash_step = timing
.as_ref()
.map(|t| t.span_rrdp_repo_step(notification_uri, "hash_snapshot"));
let _hash_total = timing
.as_ref()
.map(|t| t.span_phase("rrdp_hash_snapshot_total"));
let computed = sha2::Sha256::digest(&snapshot_xml);
if computed.as_slice() != notif.snapshot_hash_sha256.as_slice() {
return Err(RrdpError::SnapshotHashMismatch.into());
}
drop(_hash_step);
drop(_hash_total);
let _apply_step = timing
.as_ref()
.map(|t| t.span_rrdp_repo_step(notification_uri, "apply_snapshot"));
let _apply_total = timing
.as_ref()
.map(|t| t.span_phase("rrdp_apply_snapshot_total"));
let published = apply_snapshot(
let published = apply_snapshot_from_bufread(
store,
notification_uri,
&snapshot_xml,
std::io::BufReader::new(
snapshot_file
.reopen()
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile reopen failed: {e}")))?,
),
notif.session_id,
notif.serial,
)?;
@ -866,19 +870,20 @@ fn sync_from_notification_inner(
.map(|t| t.span_phase("rrdp_fetch_snapshot_total"));
let mut dl_span = download_log
.map(|dl| dl.span_download(AuditDownloadKind::RrdpSnapshot, &notif.snapshot_uri));
let snapshot_xml = match fetcher.fetch(&notif.snapshot_uri) {
let (snapshot_file, _snapshot_bytes) =
match fetch_snapshot_into_tempfile(fetcher, &notif.snapshot_uri, &notif.snapshot_hash_sha256) {
Ok(v) => {
if let Some(t) = timing.as_ref() {
t.record_count("rrdp_snapshot_fetch_ok_total", 1);
t.record_count("rrdp_snapshot_bytes_total", v.len() as u64);
t.record_count("rrdp_snapshot_bytes_total", v.1);
}
if let Some(s) = dl_span.as_mut() {
s.set_bytes(v.len() as u64);
s.set_bytes(v.1);
s.set_ok();
}
v
}
Err(e) => {
Err(RrdpSyncError::Fetch(e)) => {
if let Some(t) = timing.as_ref() {
t.record_count("rrdp_snapshot_fetch_fail_total", 1);
}
@ -887,33 +892,25 @@ fn sync_from_notification_inner(
}
return Err(RrdpSyncError::Fetch(e));
}
Err(e) => return Err(e),
};
drop(_fetch_step);
drop(_fetch_total);
let _hash_step = timing
.as_ref()
.map(|t| t.span_rrdp_repo_step(notification_uri, "hash_snapshot"));
let _hash_total = timing
.as_ref()
.map(|t| t.span_phase("rrdp_hash_snapshot_total"));
let computed = sha2::Sha256::digest(&snapshot_xml);
if computed.as_slice() != notif.snapshot_hash_sha256.as_slice() {
return Err(RrdpError::SnapshotHashMismatch.into());
}
drop(_hash_step);
drop(_hash_total);
let _apply_step = timing
.as_ref()
.map(|t| t.span_rrdp_repo_step(notification_uri, "apply_snapshot"));
let _apply_total = timing
.as_ref()
.map(|t| t.span_phase("rrdp_apply_snapshot_total"));
let published = apply_snapshot(
let published = apply_snapshot_from_bufread(
store,
notification_uri,
&snapshot_xml,
std::io::BufReader::new(
snapshot_file
.reopen()
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile reopen failed: {e}")))?,
),
notif.session_id,
notif.serial,
)?;
@ -1154,61 +1151,236 @@ fn apply_snapshot(
expected_session_id: Uuid,
expected_serial: u64,
) -> Result<usize, RrdpSyncError> {
let doc = parse_rrdp_xml(snapshot_xml)?;
let root = doc.root_element();
if root.tag_name().name() != "snapshot" {
return Err(RrdpError::UnexpectedRoot(root.tag_name().name().to_string()).into());
}
validate_root_common(&root)?;
let got_session_id = parse_uuid_attr(&root, "session_id")?;
if got_session_id != expected_session_id {
return Err(RrdpError::SnapshotSessionIdMismatch {
expected: expected_session_id.to_string(),
got: got_session_id.to_string(),
}
.into());
}
let got_serial = parse_u64_attr(&root, "serial")?;
if got_serial != expected_serial {
return Err(RrdpError::SnapshotSerialMismatch {
expected: expected_serial,
got: got_serial,
}
.into());
}
let mut published: Vec<(String, Vec<u8>)> = Vec::new();
for publish in root
.children()
.filter(|n| n.is_element() && n.tag_name().name() == "publish")
{
let uri = publish
.attribute("uri")
.ok_or(RrdpError::PublishUriMissing)?;
let content_b64 = collect_element_text(&publish).ok_or(RrdpError::PublishContentMissing)?;
let content_b64 = strip_all_ascii_whitespace(&content_b64);
let bytes = base64::engine::general_purpose::STANDARD
.decode(content_b64.as_bytes())
.map_err(|e| RrdpError::PublishBase64(e.to_string()))?;
ensure_rrdp_uri_can_be_owned_by(store, notification_uri, uri)
.map_err(RrdpSyncError::Storage)?;
published.push((uri.to_string(), bytes));
if snapshot_xml.iter().any(|&b| b > 0x7F) {
return Err(RrdpError::NotAscii.into());
}
apply_snapshot_from_bufread(
store,
notification_uri,
std::io::Cursor::new(snapshot_xml),
expected_session_id,
expected_serial,
)
}
fn apply_snapshot_from_bufread<R: BufRead>(
store: &RocksStore,
notification_uri: &str,
input: R,
expected_session_id: Uuid,
expected_serial: u64,
) -> Result<usize, RrdpSyncError> {
let previous_members: Vec<String> = store
.list_current_rrdp_source_members(notification_uri)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?
.into_iter()
.map(|record| record.rsync_uri)
.collect();
let new_set: std::collections::HashSet<&str> =
published.iter().map(|(uri, _)| uri.as_str()).collect();
let session_id = expected_session_id.to_string();
let mut new_set: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut batch_published: Vec<(String, Vec<u8>)> =
Vec::with_capacity(RRDP_SNAPSHOT_APPLY_BATCH_SIZE);
let mut published_count = 0usize;
let mut reader = Reader::from_reader(input);
reader.config_mut().trim_text(false);
let mut buf = Vec::new();
let mut root_seen = false;
let mut in_publish = false;
let mut publish_nested_depth = 0usize;
let mut current_publish_uri: Option<String> = None;
let mut current_publish_text = String::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
if !root_seen {
root_seen = true;
if local_name != b"snapshot" {
let got = String::from_utf8_lossy(local_name).to_string();
return Err(RrdpError::UnexpectedRoot(got).into());
}
let mut xmlns = String::new();
let mut version = String::new();
let mut session_id_attr = String::new();
let mut serial_attr = String::new();
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
let key = attr.key.as_ref();
let value = attr
.decode_and_unescape_value(reader.decoder())
.map_err(|e| RrdpError::Xml(e.to_string()))?
.into_owned();
match key {
b"xmlns" => xmlns = value,
b"version" => version = value,
b"session_id" => session_id_attr = value,
b"serial" => serial_attr = value,
_ => {}
}
}
if xmlns != RRDP_XMLNS {
return Err(RrdpError::InvalidNamespace(xmlns).into());
}
if version != "1" {
return Err(RrdpError::InvalidVersion(version).into());
}
let got_session_id = Uuid::parse_str(&session_id_attr)
.map_err(|_| RrdpError::InvalidSessionId(session_id_attr.clone()))?;
if got_session_id != expected_session_id {
return Err(RrdpError::SnapshotSessionIdMismatch {
expected: expected_session_id.to_string(),
got: got_session_id.to_string(),
}
.into());
}
let got_serial = parse_u64_str(&serial_attr)?;
if got_serial != expected_serial {
return Err(RrdpError::SnapshotSerialMismatch {
expected: expected_serial,
got: got_serial,
}
.into());
}
} else if in_publish {
publish_nested_depth += 1;
} else if local_name == b"publish" {
let mut uri = None;
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
if attr.key.as_ref() == b"uri" {
uri = Some(
attr.decode_and_unescape_value(reader.decoder())
.map_err(|e| RrdpError::Xml(e.to_string()))?
.into_owned(),
);
}
}
let uri = uri.ok_or(RrdpError::PublishUriMissing)?;
ensure_rrdp_uri_can_be_owned_by(store, notification_uri, &uri)
.map_err(RrdpSyncError::Storage)?;
in_publish = true;
publish_nested_depth = 0;
current_publish_uri = Some(uri);
current_publish_text.clear();
}
}
Ok(Event::Empty(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
if !root_seen {
let got = String::from_utf8_lossy(local_name).to_string();
return Err(RrdpError::UnexpectedRoot(got).into());
}
if local_name == b"publish" {
let mut has_uri = false;
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
if attr.key.as_ref() == b"uri" {
has_uri = true;
break;
}
}
if !has_uri {
return Err(RrdpError::PublishUriMissing.into());
}
return Err(RrdpError::PublishContentMissing.into());
}
}
Ok(Event::Text(e)) => {
if in_publish && publish_nested_depth == 0 {
let text = reader
.decoder()
.decode(e.as_ref())
.map_err(|e| RrdpError::Xml(e.to_string()))?;
current_publish_text.push_str(&text);
}
}
Ok(Event::CData(e)) => {
if in_publish && publish_nested_depth == 0 {
let text = reader
.decoder()
.decode(e.as_ref())
.map_err(|e| RrdpError::Xml(e.to_string()))?;
current_publish_text.push_str(&text);
}
}
Ok(Event::End(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
if in_publish {
if publish_nested_depth > 0 {
publish_nested_depth -= 1;
} else if local_name == b"publish" {
let uri = current_publish_uri
.take()
.ok_or_else(|| RrdpError::Xml("publish uri missing in state".into()))?;
let content_b64 = strip_all_ascii_whitespace(&current_publish_text);
current_publish_text.clear();
if content_b64.is_empty() {
return Err(RrdpError::PublishContentMissing.into());
}
let bytes = base64::engine::general_purpose::STANDARD
.decode(content_b64.as_bytes())
.map_err(|e| RrdpError::PublishBase64(e.to_string()))?;
new_set.insert(uri.clone());
batch_published.push((uri, bytes));
published_count += 1;
if batch_published.len() >= RRDP_SNAPSHOT_APPLY_BATCH_SIZE {
flush_snapshot_publish_batch(
store,
notification_uri,
&session_id,
expected_serial,
&batch_published,
)?;
batch_published.clear();
}
in_publish = false;
}
}
}
Ok(Event::Eof) => break,
Ok(Event::Decl(_)
| Event::PI(_)
| Event::Comment(_)
| Event::DocType(_)) => {}
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
}
buf.clear();
}
if !root_seen {
return Err(RrdpError::Xml("missing root element".to_string()).into());
}
if in_publish {
return Err(RrdpError::PublishContentMissing.into());
}
if !batch_published.is_empty() {
flush_snapshot_publish_batch(
store,
notification_uri,
&session_id,
expected_serial,
&batch_published,
)?;
batch_published.clear();
}
let mut withdrawn: Vec<(String, Option<String>)> = Vec::new();
for old_uri in &previous_members {
if new_set.contains(old_uri.as_str()) {
if new_set.contains(old_uri) {
continue;
}
let previous_hash = store
@ -1225,38 +1397,9 @@ fn apply_snapshot(
withdrawn.push((old_uri.clone(), previous_hash));
}
let session_id = expected_session_id.to_string();
let prepared_raw =
prepare_raw_by_hash_evidence_batch(store, &published).map_err(RrdpSyncError::Storage)?;
let mut repository_view_entries = Vec::with_capacity(published.len() + withdrawn.len());
let mut member_records = Vec::with_capacity(published.len() + withdrawn.len());
let mut owner_records = Vec::with_capacity(published.len() + withdrawn.len());
for (uri, _bytes) in &published {
let current_hash = prepared_raw.uri_to_hash.get(uri).cloned().ok_or_else(|| {
RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}"))
})?;
repository_view_entries.push(build_repository_view_present_entry(
notification_uri,
uri,
&current_hash,
));
member_records.push(build_rrdp_source_member_present_record(
notification_uri,
&session_id,
expected_serial,
uri,
&current_hash,
));
owner_records.push(build_rrdp_uri_owner_active_record(
notification_uri,
&session_id,
expected_serial,
uri,
&current_hash,
));
}
let mut repository_view_entries = Vec::with_capacity(withdrawn.len());
let mut member_records = Vec::with_capacity(withdrawn.len());
let mut owner_records = Vec::with_capacity(withdrawn.len());
for (uri, previous_hash) in withdrawn {
member_records.push(build_rrdp_source_member_withdrawn_record(
notification_uri,
@ -1280,6 +1423,50 @@ fn apply_snapshot(
));
}
}
store
.put_projection_batch(&repository_view_entries, &member_records, &owner_records)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
Ok(published_count)
}
fn flush_snapshot_publish_batch(
store: &RocksStore,
notification_uri: &str,
session_id: &str,
serial: u64,
published: &[(String, Vec<u8>)],
) -> Result<(), RrdpSyncError> {
let prepared_raw =
prepare_raw_by_hash_evidence_batch(store, published).map_err(RrdpSyncError::Storage)?;
let mut repository_view_entries = Vec::with_capacity(published.len());
let mut member_records = Vec::with_capacity(published.len());
let mut owner_records = Vec::with_capacity(published.len());
for (uri, _bytes) in published {
let current_hash = prepared_raw.uri_to_hash.get(uri).cloned().ok_or_else(|| {
RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}"))
})?;
repository_view_entries.push(build_repository_view_present_entry(
notification_uri,
uri,
&current_hash,
));
member_records.push(build_rrdp_source_member_present_record(
notification_uri,
session_id,
serial,
uri,
&current_hash,
));
owner_records.push(build_rrdp_uri_owner_active_record(
notification_uri,
session_id,
serial,
uri,
&current_hash,
));
}
store
.put_raw_by_hash_entries_batch_unchecked(&prepared_raw.entries_to_write)
@ -1288,7 +1475,81 @@ fn apply_snapshot(
.put_projection_batch(&repository_view_entries, &member_records, &owner_records)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
Ok(published.len())
Ok(())
}
// Marker message used to recognize the ASCII-validation failure after it has
// been stringified by a fetcher into its error type.
const SNAPSHOT_NON_ASCII_ERROR: &str = "snapshot body contains non-ASCII bytes";
// Write adapter that spools snapshot bytes to `inner` while hashing them and
// counting bytes, rejecting any non-ASCII byte at write time.
struct SnapshotSpoolWriter<'a, W: Write> {
// Destination sink (in practice the snapshot tempfile).
inner: &'a mut W,
// Running SHA-256 over all bytes successfully written to `inner`.
hasher: sha2::Sha256,
// Total number of bytes successfully written.
bytes: u64,
}
impl<'a, W: Write> SnapshotSpoolWriter<'a, W> {
    /// Wraps `inner` with a fresh SHA-256 state and a zero byte counter.
    fn new(inner: &'a mut W) -> Self {
        let hasher = sha2::Sha256::new();
        Self { inner, hasher, bytes: 0 }
    }

    /// Total number of bytes successfully forwarded to the inner writer.
    fn bytes_written(&self) -> u64 {
        self.bytes
    }

    /// Consumes the writer and returns the SHA-256 digest of all written bytes.
    fn finalize_hash(self) -> [u8; 32] {
        let mut digest = [0u8; 32];
        digest.copy_from_slice(&self.hasher.finalize());
        digest
    }
}
impl<W: Write> Write for SnapshotSpoolWriter<'_, W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if buf.iter().any(|&b| b > 0x7F) {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
SNAPSHOT_NON_ASCII_ERROR,
));
}
let n = self.inner.write(buf)?;
self.hasher.update(&buf[..n]);
self.bytes += n as u64;
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> {
self.inner.flush()
}
}
/// Streams the snapshot at `snapshot_uri` into a named tempfile while
/// computing its SHA-256, enforcing ASCII-only content, and verifying the
/// digest against `expected_hash_sha256`.
///
/// On success returns the tempfile (handle rewound to byte 0) together with
/// the number of bytes written.
fn fetch_snapshot_into_tempfile(
fetcher: &dyn Fetcher,
snapshot_uri: &str,
expected_hash_sha256: &[u8; 32],
) -> Result<(tempfile::NamedTempFile, u64), RrdpSyncError> {
let mut tmp =
tempfile::NamedTempFile::new().map_err(|e| RrdpSyncError::Fetch(format!("tempfile create failed: {e}")))?;
// The spool wrapper hashes and ASCII-checks bytes as they are written.
let mut spool = SnapshotSpoolWriter::new(tmp.as_file_mut());
let bytes_written = match fetcher.fetch_to_writer(snapshot_uri, &mut spool) {
Ok(bytes) => bytes,
// NOTE(review): the ASCII failure surfaces as an io::Error that the
// fetcher stringifies; this substring match assumes every Fetcher impl
// preserves the source error text in its message — TODO confirm.
Err(e) if e.contains(SNAPSHOT_NON_ASCII_ERROR) => return Err(RrdpError::NotAscii.into()),
Err(e) => return Err(RrdpSyncError::Fetch(e)),
};
// Verify integrity before handing the file to the caller.
let computed = spool.finalize_hash();
if computed.as_slice() != expected_hash_sha256.as_slice() {
return Err(RrdpError::SnapshotHashMismatch.into());
}
tmp.as_file_mut()
.flush()
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile flush failed: {e}")))?;
// Rewind so a caller reading through this same handle starts at byte 0
// (callers using reopen() get a fresh cursor regardless).
tmp.as_file_mut()
.seek(SeekFrom::Start(0))
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile rewind failed: {e}")))?;
Ok((tmp, bytes_written))
}
fn parse_rrdp_xml(xml: &[u8]) -> Result<roxmltree::Document<'_>, RrdpError> {
@ -2471,4 +2732,37 @@ mod tests {
RrdpSyncError::Rrdp(RrdpError::PublishBase64(_))
));
}
#[test]
fn apply_snapshot_handles_multiple_publish_batches() {
    // Use more publishes than one batch so the streaming apply path must
    // flush at least twice.
    let dir = tempfile::tempdir().expect("tempdir");
    let store = RocksStore::open(dir.path()).expect("open rocksdb");
    let notification = "https://example.net/notification.xml";
    let session = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
    let publish_count = RRDP_SNAPSHOT_APPLY_BATCH_SIZE + 7;

    let mut body = String::new();
    body.push_str(&format!(
        r#"<snapshot xmlns="{RRDP_XMLNS}" version="1" session_id="{session}" serial="1">"#
    ));
    for i in 0..publish_count {
        let object_uri = format!("rsync://example.net/repo/{i:04}.roa");
        let encoded = base64::engine::general_purpose::STANDARD
            .encode(format!("payload-{i}").into_bytes());
        body.push_str(&format!(r#"<publish uri="{object_uri}">{encoded}</publish>"#));
    }
    body.push_str("</snapshot>");

    let applied = apply_snapshot(&store, notification, body.as_bytes(), session, 1)
        .expect("apply snapshot");
    assert_eq!(applied, publish_count);

    // Spot-check the first batch, the batch boundary, and the tail.
    for idx in [0usize, RRDP_SNAPSHOT_APPLY_BATCH_SIZE - 1, publish_count - 1] {
        let object_uri = format!("rsync://example.net/repo/{idx:04}.roa");
        let stored = store
            .load_current_object_bytes_by_uri(&object_uri)
            .expect("load object")
            .expect("object exists");
        assert_eq!(stored, format!("payload-{idx}").into_bytes());
    }
}
}