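//! RRDP snapshot handling: applying a snapshot document to the store and
//! spooling a fetched snapshot into a temporary file while verifying its hash.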
use base64::Engine;
|
|
use quick_xml::Reader;
|
|
use quick_xml::events::Event;
|
|
use sha2::Digest;
|
|
use std::io::{BufRead, Seek, SeekFrom, Write};
|
|
use uuid::Uuid;
|
|
|
|
use crate::current_repo_index::CurrentRepoIndexHandle;
|
|
use crate::storage::RocksStore;
|
|
use crate::sync::store_projection::{
|
|
build_repository_view_present_entry, build_repository_view_withdrawn_entry,
|
|
build_rrdp_source_member_present_record, build_rrdp_source_member_withdrawn_record,
|
|
build_rrdp_uri_owner_active_record, build_rrdp_uri_owner_withdrawn_record, compute_sha256_hex,
|
|
current_rrdp_owner_is, ensure_rrdp_uri_can_be_owned_by, prepare_repo_bytes_batch,
|
|
};
|
|
|
|
use super::{
|
|
Fetcher, RRDP_SNAPSHOT_APPLY_BATCH_SIZE, RRDP_XMLNS, RrdpError, RrdpSyncError, parse_u64_str,
|
|
strip_all_ascii_whitespace,
|
|
};
|
|
|
|
/// Test-only entry point that applies a snapshot held entirely in memory.
///
/// The byte slice must be pure ASCII; otherwise `RrdpError::NotAscii` is
/// returned before any parsing happens. The remaining work is delegated to
/// [`apply_snapshot_from_bufread`] with the slice wrapped in a cursor, and
/// that function's result is returned unchanged.
#[cfg(test)]
pub(super) fn apply_snapshot(
    store: &RocksStore,
    notification_uri: &str,
    current_repo_index: Option<&CurrentRepoIndexHandle>,
    snapshot_xml: &[u8],
    expected_session_id: Uuid,
    expected_serial: u64,
) -> Result<usize, RrdpSyncError> {
    // Reject any byte outside the 7-bit ASCII range up front.
    if !snapshot_xml.is_ascii() {
        return Err(RrdpError::NotAscii.into());
    }

    let cursor = std::io::Cursor::new(snapshot_xml);
    apply_snapshot_from_bufread(
        store,
        notification_uri,
        current_repo_index,
        cursor,
        expected_session_id,
        expected_serial,
    )
}
|
|
|
|
pub(super) fn apply_snapshot_from_bufread<R: BufRead>(
|
|
store: &RocksStore,
|
|
notification_uri: &str,
|
|
current_repo_index: Option<&CurrentRepoIndexHandle>,
|
|
input: R,
|
|
expected_session_id: Uuid,
|
|
expected_serial: u64,
|
|
) -> Result<usize, RrdpSyncError> {
|
|
let previous_members: Vec<String> = store
|
|
.list_current_rrdp_source_members(notification_uri)
|
|
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?
|
|
.into_iter()
|
|
.map(|record| record.rsync_uri)
|
|
.collect();
|
|
let session_id = expected_session_id.to_string();
|
|
let mut new_set: std::collections::HashSet<String> = std::collections::HashSet::new();
|
|
let mut batch_published: Vec<(String, Vec<u8>)> =
|
|
Vec::with_capacity(RRDP_SNAPSHOT_APPLY_BATCH_SIZE);
|
|
let mut published_count = 0usize;
|
|
|
|
let mut reader = Reader::from_reader(input);
|
|
reader.config_mut().trim_text(false);
|
|
let mut buf = Vec::new();
|
|
let mut root_seen = false;
|
|
let mut in_publish = false;
|
|
let mut publish_nested_depth = 0usize;
|
|
let mut current_publish_uri: Option<String> = None;
|
|
let mut current_publish_text = String::new();
|
|
|
|
loop {
|
|
match reader.read_event_into(&mut buf) {
|
|
Ok(Event::Start(e)) => {
|
|
let local_name = e.local_name();
|
|
let local_name = local_name.as_ref();
|
|
if !root_seen {
|
|
root_seen = true;
|
|
if local_name != b"snapshot" {
|
|
let got = String::from_utf8_lossy(local_name).to_string();
|
|
return Err(RrdpError::UnexpectedRoot(got).into());
|
|
}
|
|
let mut xmlns = String::new();
|
|
let mut version = String::new();
|
|
let mut session_id_attr = String::new();
|
|
let mut serial_attr = String::new();
|
|
for attr in e.attributes().with_checks(false) {
|
|
let attr = match attr {
|
|
Ok(attr) => attr,
|
|
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
|
|
};
|
|
let key = attr.key.as_ref();
|
|
let value = attr
|
|
.decode_and_unescape_value(reader.decoder())
|
|
.map_err(|e| RrdpError::Xml(e.to_string()))?
|
|
.into_owned();
|
|
match key {
|
|
b"xmlns" => xmlns = value,
|
|
b"version" => version = value,
|
|
b"session_id" => session_id_attr = value,
|
|
b"serial" => serial_attr = value,
|
|
_ => {}
|
|
}
|
|
}
|
|
if xmlns != RRDP_XMLNS {
|
|
return Err(RrdpError::InvalidNamespace(xmlns).into());
|
|
}
|
|
if version != "1" {
|
|
return Err(RrdpError::InvalidVersion(version).into());
|
|
}
|
|
let got_session_id = Uuid::parse_str(&session_id_attr)
|
|
.map_err(|_| RrdpError::InvalidSessionId(session_id_attr.clone()))?;
|
|
if got_session_id != expected_session_id {
|
|
return Err(RrdpError::SnapshotSessionIdMismatch {
|
|
expected: expected_session_id.to_string(),
|
|
got: got_session_id.to_string(),
|
|
}
|
|
.into());
|
|
}
|
|
let got_serial = parse_u64_str(&serial_attr)?;
|
|
if got_serial != expected_serial {
|
|
return Err(RrdpError::SnapshotSerialMismatch {
|
|
expected: expected_serial,
|
|
got: got_serial,
|
|
}
|
|
.into());
|
|
}
|
|
} else if in_publish {
|
|
publish_nested_depth += 1;
|
|
} else if local_name == b"publish" {
|
|
let mut uri = None;
|
|
for attr in e.attributes().with_checks(false) {
|
|
let attr = match attr {
|
|
Ok(attr) => attr,
|
|
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
|
|
};
|
|
if attr.key.as_ref() == b"uri" {
|
|
uri = Some(
|
|
attr.decode_and_unescape_value(reader.decoder())
|
|
.map_err(|e| RrdpError::Xml(e.to_string()))?
|
|
.into_owned(),
|
|
);
|
|
}
|
|
}
|
|
let uri = uri.ok_or(RrdpError::PublishUriMissing)?;
|
|
ensure_rrdp_uri_can_be_owned_by(store, notification_uri, &uri)
|
|
.map_err(RrdpSyncError::Storage)?;
|
|
in_publish = true;
|
|
publish_nested_depth = 0;
|
|
current_publish_uri = Some(uri);
|
|
current_publish_text.clear();
|
|
}
|
|
}
|
|
Ok(Event::Empty(e)) => {
|
|
let local_name = e.local_name();
|
|
let local_name = local_name.as_ref();
|
|
if !root_seen {
|
|
let got = String::from_utf8_lossy(local_name).to_string();
|
|
return Err(RrdpError::UnexpectedRoot(got).into());
|
|
}
|
|
if local_name == b"publish" {
|
|
let mut has_uri = false;
|
|
for attr in e.attributes().with_checks(false) {
|
|
let attr = match attr {
|
|
Ok(attr) => attr,
|
|
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
|
|
};
|
|
if attr.key.as_ref() == b"uri" {
|
|
has_uri = true;
|
|
break;
|
|
}
|
|
}
|
|
if !has_uri {
|
|
return Err(RrdpError::PublishUriMissing.into());
|
|
}
|
|
return Err(RrdpError::PublishContentMissing.into());
|
|
}
|
|
}
|
|
Ok(Event::Text(e)) => {
|
|
if in_publish && publish_nested_depth == 0 {
|
|
let text = reader
|
|
.decoder()
|
|
.decode(e.as_ref())
|
|
.map_err(|e| RrdpError::Xml(e.to_string()))?;
|
|
current_publish_text.push_str(&text);
|
|
}
|
|
}
|
|
Ok(Event::CData(e)) => {
|
|
if in_publish && publish_nested_depth == 0 {
|
|
let text = reader
|
|
.decoder()
|
|
.decode(e.as_ref())
|
|
.map_err(|e| RrdpError::Xml(e.to_string()))?;
|
|
current_publish_text.push_str(&text);
|
|
}
|
|
}
|
|
Ok(Event::End(e)) => {
|
|
let local_name = e.local_name();
|
|
let local_name = local_name.as_ref();
|
|
if in_publish {
|
|
if publish_nested_depth > 0 {
|
|
publish_nested_depth -= 1;
|
|
} else if local_name == b"publish" {
|
|
let uri = current_publish_uri
|
|
.take()
|
|
.ok_or_else(|| RrdpError::Xml("publish uri missing in state".into()))?;
|
|
let content_b64 = strip_all_ascii_whitespace(¤t_publish_text);
|
|
current_publish_text.clear();
|
|
if content_b64.is_empty() {
|
|
return Err(RrdpError::PublishContentMissing.into());
|
|
}
|
|
let bytes = base64::engine::general_purpose::STANDARD
|
|
.decode(content_b64.as_bytes())
|
|
.map_err(|e| RrdpError::PublishBase64(e.to_string()))?;
|
|
new_set.insert(uri.clone());
|
|
batch_published.push((uri, bytes));
|
|
published_count += 1;
|
|
if batch_published.len() >= RRDP_SNAPSHOT_APPLY_BATCH_SIZE {
|
|
flush_snapshot_publish_batch(
|
|
store,
|
|
notification_uri,
|
|
current_repo_index,
|
|
&session_id,
|
|
expected_serial,
|
|
&batch_published,
|
|
)?;
|
|
batch_published.clear();
|
|
}
|
|
in_publish = false;
|
|
}
|
|
}
|
|
}
|
|
Ok(Event::Eof) => break,
|
|
Ok(Event::Decl(_) | Event::PI(_) | Event::Comment(_) | Event::DocType(_)) => {}
|
|
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
|
|
}
|
|
buf.clear();
|
|
}
|
|
|
|
if !root_seen {
|
|
return Err(RrdpError::Xml("missing root element".to_string()).into());
|
|
}
|
|
if in_publish {
|
|
return Err(RrdpError::PublishContentMissing.into());
|
|
}
|
|
if !batch_published.is_empty() {
|
|
flush_snapshot_publish_batch(
|
|
store,
|
|
notification_uri,
|
|
current_repo_index,
|
|
&session_id,
|
|
expected_serial,
|
|
&batch_published,
|
|
)?;
|
|
batch_published.clear();
|
|
}
|
|
|
|
let mut withdrawn: Vec<(String, Option<String>)> = Vec::new();
|
|
for old_uri in &previous_members {
|
|
if new_set.contains(old_uri) {
|
|
continue;
|
|
}
|
|
let previous_hash = store
|
|
.get_repository_view_entry(old_uri)
|
|
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?
|
|
.and_then(|entry| entry.current_hash)
|
|
.or_else(|| {
|
|
store
|
|
.load_current_object_bytes_by_uri(old_uri)
|
|
.ok()
|
|
.flatten()
|
|
.map(|bytes| compute_sha256_hex(&bytes))
|
|
});
|
|
withdrawn.push((old_uri.clone(), previous_hash));
|
|
}
|
|
|
|
let mut repository_view_entries = Vec::with_capacity(withdrawn.len());
|
|
let mut member_records = Vec::with_capacity(withdrawn.len());
|
|
let mut owner_records = Vec::with_capacity(withdrawn.len());
|
|
for (uri, previous_hash) in withdrawn {
|
|
member_records.push(build_rrdp_source_member_withdrawn_record(
|
|
notification_uri,
|
|
&session_id,
|
|
expected_serial,
|
|
&uri,
|
|
previous_hash.clone(),
|
|
));
|
|
if current_rrdp_owner_is(store, notification_uri, &uri).map_err(RrdpSyncError::Storage)? {
|
|
repository_view_entries.push(build_repository_view_withdrawn_entry(
|
|
notification_uri,
|
|
&uri,
|
|
previous_hash.clone(),
|
|
));
|
|
owner_records.push(build_rrdp_uri_owner_withdrawn_record(
|
|
notification_uri,
|
|
&session_id,
|
|
expected_serial,
|
|
&uri,
|
|
previous_hash,
|
|
));
|
|
}
|
|
}
|
|
store
|
|
.put_projection_batch(&repository_view_entries, &member_records, &owner_records)
|
|
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
|
|
if let Some(index) = current_repo_index {
|
|
index
|
|
.lock()
|
|
.map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))?
|
|
.apply_repository_view_entries(&repository_view_entries)
|
|
.map_err(RrdpSyncError::Storage)?;
|
|
}
|
|
|
|
Ok(published_count)
|
|
}
|
|
|
|
fn flush_snapshot_publish_batch(
|
|
store: &RocksStore,
|
|
notification_uri: &str,
|
|
current_repo_index: Option<&CurrentRepoIndexHandle>,
|
|
session_id: &str,
|
|
serial: u64,
|
|
published: &[(String, Vec<u8>)],
|
|
) -> Result<(), RrdpSyncError> {
|
|
let prepared_bytes = prepare_repo_bytes_batch(published).map_err(RrdpSyncError::Storage)?;
|
|
let mut repository_view_entries = Vec::with_capacity(published.len());
|
|
let mut member_records = Vec::with_capacity(published.len());
|
|
let mut owner_records = Vec::with_capacity(published.len());
|
|
|
|
for (uri, _bytes) in published {
|
|
let current_hash = prepared_bytes
|
|
.uri_to_hash
|
|
.get(uri)
|
|
.cloned()
|
|
.ok_or_else(|| {
|
|
RrdpSyncError::Storage(format!("missing raw_by_hash mapping for {uri}"))
|
|
})?;
|
|
repository_view_entries.push(build_repository_view_present_entry(
|
|
notification_uri,
|
|
uri,
|
|
¤t_hash,
|
|
));
|
|
member_records.push(build_rrdp_source_member_present_record(
|
|
notification_uri,
|
|
session_id,
|
|
serial,
|
|
uri,
|
|
¤t_hash,
|
|
));
|
|
owner_records.push(build_rrdp_uri_owner_active_record(
|
|
notification_uri,
|
|
session_id,
|
|
serial,
|
|
uri,
|
|
¤t_hash,
|
|
));
|
|
}
|
|
|
|
store
|
|
.put_blob_bytes_batch(&prepared_bytes.blobs_to_write)
|
|
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
|
|
store
|
|
.put_projection_batch(&repository_view_entries, &member_records, &owner_records)
|
|
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
|
|
if let Some(index) = current_repo_index {
|
|
index
|
|
.lock()
|
|
.map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))?
|
|
.apply_repository_view_entries(&repository_view_entries)
|
|
.map_err(RrdpSyncError::Storage)?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Message of the `InvalidData` I/O error that `SnapshotSpoolWriter::write`
/// returns for non-ASCII input. `fetch_snapshot_into_tempfile` searches the
/// fetcher's error text for this exact string to report `RrdpError::NotAscii`,
/// so the two uses must stay in sync.
const SNAPSHOT_NON_ASCII_ERROR: &str = "snapshot body contains non-ASCII bytes";
|
|
|
|
struct SnapshotSpoolWriter<'a, W: Write> {
|
|
inner: &'a mut W,
|
|
hasher: sha2::Sha256,
|
|
bytes: u64,
|
|
}
|
|
|
|
impl<'a, W: Write> SnapshotSpoolWriter<'a, W> {
|
|
fn new(inner: &'a mut W) -> Self {
|
|
Self {
|
|
inner,
|
|
hasher: sha2::Sha256::new(),
|
|
bytes: 0,
|
|
}
|
|
}
|
|
|
|
fn finalize_hash(self) -> [u8; 32] {
|
|
let digest = self.hasher.finalize();
|
|
let mut out = [0u8; 32];
|
|
out.copy_from_slice(&digest);
|
|
out
|
|
}
|
|
}
|
|
|
|
impl<W: Write> Write for SnapshotSpoolWriter<'_, W> {
|
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
if buf.iter().any(|&b| b > 0x7F) {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidData,
|
|
SNAPSHOT_NON_ASCII_ERROR,
|
|
));
|
|
}
|
|
let n = self.inner.write(buf)?;
|
|
self.hasher.update(&buf[..n]);
|
|
self.bytes += n as u64;
|
|
Ok(n)
|
|
}
|
|
|
|
fn flush(&mut self) -> std::io::Result<()> {
|
|
self.inner.flush()
|
|
}
|
|
}
|
|
|
|
pub(super) fn fetch_snapshot_into_tempfile(
|
|
fetcher: &dyn Fetcher,
|
|
snapshot_uri: &str,
|
|
expected_hash_sha256: &[u8; 32],
|
|
) -> Result<(tempfile::NamedTempFile, u64), RrdpSyncError> {
|
|
let mut tmp = tempfile::NamedTempFile::new()
|
|
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile create failed: {e}")))?;
|
|
let mut spool = SnapshotSpoolWriter::new(tmp.as_file_mut());
|
|
let bytes_written = match fetcher.fetch_to_writer(snapshot_uri, &mut spool) {
|
|
Ok(bytes) => bytes,
|
|
Err(e) if e.contains(SNAPSHOT_NON_ASCII_ERROR) => return Err(RrdpError::NotAscii.into()),
|
|
Err(e) => return Err(RrdpSyncError::Fetch(e)),
|
|
};
|
|
let computed = spool.finalize_hash();
|
|
if computed.as_slice() != expected_hash_sha256.as_slice() {
|
|
return Err(RrdpError::SnapshotHashMismatch.into());
|
|
}
|
|
tmp.as_file_mut()
|
|
.flush()
|
|
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile flush failed: {e}")))?;
|
|
tmp.as_file_mut()
|
|
.seek(SeekFrom::Start(0))
|
|
.map_err(|e| RrdpSyncError::Fetch(format!("tempfile rewind failed: {e}")))?;
|
|
Ok((tmp, bytes_written))
|
|
}
|