rpki/src/sync/rrdp/snapshot_apply.rs

446 lines
17 KiB
Rust

use base64::Engine;
use quick_xml::Reader;
use quick_xml::events::Event;
use sha2::Digest;
use std::io::{BufRead, Seek, SeekFrom, Write};
use uuid::Uuid;
use crate::current_repo_index::CurrentRepoIndexHandle;
use crate::storage::RocksStore;
use crate::sync::store_projection::{
build_repository_view_present_entry, build_repository_view_withdrawn_entry,
build_rrdp_source_member_present_record, build_rrdp_source_member_withdrawn_record,
build_rrdp_uri_owner_active_record, build_rrdp_uri_owner_withdrawn_record, compute_sha256_hex,
current_rrdp_owner_is, ensure_rrdp_uri_can_be_owned_by, prepare_repo_bytes_batch,
};
use super::{
Fetcher, RRDP_SNAPSHOT_APPLY_BATCH_SIZE, RRDP_XMLNS, RrdpError, RrdpSyncError, parse_u64_str,
strip_all_ascii_whitespace,
};
/// Test-only entry point that applies a snapshot held entirely in memory.
///
/// Rejects the document with `RrdpError::NotAscii` if any byte is outside the ASCII range,
/// then delegates to `apply_snapshot_from_bufread` with an in-memory cursor over
/// `snapshot_xml`. The return value and all other errors are those of the delegate.
#[cfg(test)]
pub(super) fn apply_snapshot(
    store: &RocksStore,
    notification_uri: &str,
    current_repo_index: Option<&CurrentRepoIndexHandle>,
    snapshot_xml: &[u8],
    expected_session_id: Uuid,
    expected_serial: u64,
) -> Result<usize, RrdpSyncError> {
    if !snapshot_xml.is_ascii() {
        return Err(RrdpError::NotAscii.into());
    }

    let cursor = std::io::Cursor::new(snapshot_xml);
    apply_snapshot_from_bufread(
        store,
        notification_uri,
        current_repo_index,
        cursor,
        expected_session_id,
        expected_serial,
    )
}
/// Streams an RRDP snapshot document from `input` and applies it to `store`.
///
/// The root element must be named `snapshot`, carry `xmlns` equal to `RRDP_XMLNS` and
/// `version` equal to `"1"`, and its `session_id` / `serial` attributes must match
/// `expected_session_id` / `expected_serial`. The text of each `<publish uri="...">` element
/// is passed through `strip_all_ascii_whitespace`, base64-decoded and queued. Queued objects
/// are written through `flush_snapshot_publish_batch` whenever
/// `RRDP_SNAPSHOT_APPLY_BATCH_SIZE` of them have accumulated, and once more at end of input.
/// Afterwards every URI previously listed as a member of `notification_uri` that this
/// snapshot did not publish is recorded as withdrawn.
///
/// Returns the number of `<publish>` elements that were applied.
///
/// # Errors
///
/// Returns `RrdpSyncError::Storage` when a store or index operation fails, and an
/// `RrdpError` converted into `RrdpSyncError` for malformed XML, an unexpected root,
/// namespace or version, a session id or serial mismatch, a `<publish>` without a `uri`
/// attribute or without content, or invalid base64.
///
/// Batches are flushed while parsing is still in progress, so an error detected late in the
/// document is returned after earlier batches were already handed to the store.
pub(super) fn apply_snapshot_from_bufread<R: BufRead>(
store: &RocksStore,
notification_uri: &str,
current_repo_index: Option<&CurrentRepoIndexHandle>,
input: R,
expected_session_id: Uuid,
expected_serial: u64,
) -> Result<usize, RrdpSyncError> {
// Remember which URIs were members before this snapshot so that the ones it omits can be
// withdrawn after parsing.
let previous_members: Vec<String> = store
.list_current_rrdp_source_members(notification_uri)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?
.into_iter()
.map(|record| record.rsync_uri)
.collect();
let session_id = expected_session_id.to_string();
// URIs published by this snapshot.
let mut new_set: std::collections::HashSet<String> = std::collections::HashSet::new();
// Decoded objects waiting for the next flush.
let mut batch_published: Vec<(String, Vec<u8>)> =
Vec::with_capacity(RRDP_SNAPSHOT_APPLY_BATCH_SIZE);
let mut published_count = 0usize;
let mut reader = Reader::from_reader(input);
// Text nodes are kept untrimmed; whitespace is removed from the base64 body explicitly
// when the closing `</publish>` is handled.
reader.config_mut().trim_text(false);
let mut buf = Vec::new();
// Parser state: whether the root was seen, whether we are inside a `<publish>`, and how
// many child elements are currently open inside that `<publish>`.
let mut root_seen = false;
let mut in_publish = false;
let mut publish_nested_depth = 0usize;
let mut current_publish_uri: Option<String> = None;
let mut current_publish_text = String::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
if !root_seen {
// The first start tag is the document root and carries the snapshot metadata.
root_seen = true;
if local_name != b"snapshot" {
let got = String::from_utf8_lossy(local_name).to_string();
return Err(RrdpError::UnexpectedRoot(got).into());
}
let mut xmlns = String::new();
let mut version = String::new();
let mut session_id_attr = String::new();
let mut serial_attr = String::new();
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
let key = attr.key.as_ref();
let value = attr
.decode_and_unescape_value(reader.decoder())
.map_err(|e| RrdpError::Xml(e.to_string()))?
.into_owned();
match key {
b"xmlns" => xmlns = value,
b"version" => version = value,
b"session_id" => session_id_attr = value,
b"serial" => serial_attr = value,
_ => {}
}
}
// Attributes that are absent stay empty strings and therefore fail the checks below.
if xmlns != RRDP_XMLNS {
return Err(RrdpError::InvalidNamespace(xmlns).into());
}
if version != "1" {
return Err(RrdpError::InvalidVersion(version).into());
}
let got_session_id = Uuid::parse_str(&session_id_attr)
.map_err(|_| RrdpError::InvalidSessionId(session_id_attr.clone()))?;
if got_session_id != expected_session_id {
return Err(RrdpError::SnapshotSessionIdMismatch {
expected: expected_session_id.to_string(),
got: got_session_id.to_string(),
}
.into());
}
let got_serial = parse_u64_str(&serial_attr)?;
if got_serial != expected_serial {
return Err(RrdpError::SnapshotSerialMismatch {
expected: expected_serial,
got: got_serial,
}
.into());
}
} else if in_publish {
// Elements opened inside a `<publish>` are only tracked for nesting depth.
publish_nested_depth += 1;
} else if local_name == b"publish" {
let mut uri = None;
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
if attr.key.as_ref() == b"uri" {
uri = Some(
attr.decode_and_unescape_value(reader.decoder())
.map_err(|e| RrdpError::Xml(e.to_string()))?
.into_owned(),
);
}
}
let uri = uri.ok_or(RrdpError::PublishUriMissing)?;
// Fail before reading the body if `ensure_rrdp_uri_can_be_owned_by` refuses this URI.
ensure_rrdp_uri_can_be_owned_by(store, notification_uri, &uri)
.map_err(RrdpSyncError::Storage)?;
in_publish = true;
publish_nested_depth = 0;
current_publish_uri = Some(uri);
current_publish_text.clear();
}
}
Ok(Event::Empty(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
// NOTE(review): a self-closing element before any start tag is reported as an
// unexpected root even when it is named `snapshot` — confirm this is intended.
if !root_seen {
let got = String::from_utf8_lossy(local_name).to_string();
return Err(RrdpError::UnexpectedRoot(got).into());
}
// A self-closing `<publish/>` has no body: report a missing `uri` first,
// otherwise missing content.
if local_name == b"publish" {
let mut has_uri = false;
for attr in e.attributes().with_checks(false) {
let attr = match attr {
Ok(attr) => attr,
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
};
if attr.key.as_ref() == b"uri" {
has_uri = true;
break;
}
}
if !has_uri {
return Err(RrdpError::PublishUriMissing.into());
}
return Err(RrdpError::PublishContentMissing.into());
}
}
Ok(Event::Text(e)) => {
// Only text directly inside `<publish>` (not inside nested children) is collected.
if in_publish && publish_nested_depth == 0 {
let text = reader
.decoder()
.decode(e.as_ref())
.map_err(|e| RrdpError::Xml(e.to_string()))?;
current_publish_text.push_str(&text);
}
}
Ok(Event::CData(e)) => {
// CDATA sections are collected under the same condition as plain text.
if in_publish && publish_nested_depth == 0 {
let text = reader
.decoder()
.decode(e.as_ref())
.map_err(|e| RrdpError::Xml(e.to_string()))?;
current_publish_text.push_str(&text);
}
}
Ok(Event::End(e)) => {
let local_name = e.local_name();
let local_name = local_name.as_ref();
if in_publish {
if publish_nested_depth > 0 {
// Closing a nested child element.
publish_nested_depth -= 1;
} else if local_name == b"publish" {
// Closing the `<publish>` itself: decode the collected body and queue it.
let uri = current_publish_uri
.take()
.ok_or_else(|| RrdpError::Xml("publish uri missing in state".into()))?;
let content_b64 = strip_all_ascii_whitespace(&current_publish_text);
current_publish_text.clear();
if content_b64.is_empty() {
return Err(RrdpError::PublishContentMissing.into());
}
let bytes = base64::engine::general_purpose::STANDARD
.decode(content_b64.as_bytes())
.map_err(|e| RrdpError::PublishBase64(e.to_string()))?;
new_set.insert(uri.clone());
batch_published.push((uri, bytes));
published_count += 1;
// Flush as soon as the batch reaches the configured size.
if batch_published.len() >= RRDP_SNAPSHOT_APPLY_BATCH_SIZE {
flush_snapshot_publish_batch(
store,
notification_uri,
current_repo_index,
&session_id,
expected_serial,
&batch_published,
)?;
batch_published.clear();
}
in_publish = false;
}
}
}
Ok(Event::Eof) => break,
Ok(Event::Decl(_) | Event::PI(_) | Event::Comment(_) | Event::DocType(_)) => {}
Err(e) => return Err(RrdpError::Xml(e.to_string()).into()),
}
buf.clear();
}
if !root_seen {
return Err(RrdpError::Xml("missing root element".to_string()).into());
}
// Input ended while a `<publish>` element was still open.
if in_publish {
return Err(RrdpError::PublishContentMissing.into());
}
// Write whatever is left over from the last, partially filled batch.
if !batch_published.is_empty() {
flush_snapshot_publish_batch(
store,
notification_uri,
current_repo_index,
&session_id,
expected_serial,
&batch_published,
)?;
batch_published.clear();
}
// Collect previous members that the snapshot no longer publishes, together with their
// last known hash.
let mut withdrawn: Vec<(String, Option<String>)> = Vec::new();
for old_uri in &previous_members {
if new_set.contains(old_uri) {
continue;
}
// Prefer the hash recorded in the repository view; otherwise hash the stored bytes.
// NOTE(review): errors from the fallback load are discarded and leave the hash as
// `None` — confirm this best-effort behavior is intended.
let previous_hash = store
.get_repository_view_entry(old_uri)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?
.and_then(|entry| entry.current_hash)
.or_else(|| {
store
.load_current_object_bytes_by_uri(old_uri)
.ok()
.flatten()
.map(|bytes| compute_sha256_hex(&bytes))
});
withdrawn.push((old_uri.clone(), previous_hash));
}
let mut repository_view_entries = Vec::with_capacity(withdrawn.len());
let mut member_records = Vec::with_capacity(withdrawn.len());
let mut owner_records = Vec::with_capacity(withdrawn.len());
for (uri, previous_hash) in withdrawn {
// A withdrawn member record is always written; view and owner records are added
// only when `current_rrdp_owner_is` returns true for this notification source.
member_records.push(build_rrdp_source_member_withdrawn_record(
notification_uri,
&session_id,
expected_serial,
&uri,
previous_hash.clone(),
));
if current_rrdp_owner_is(store, notification_uri, &uri).map_err(RrdpSyncError::Storage)? {
repository_view_entries.push(build_repository_view_withdrawn_entry(
notification_uri,
&uri,
previous_hash.clone(),
));
owner_records.push(build_rrdp_uri_owner_withdrawn_record(
notification_uri,
&session_id,
expected_serial,
&uri,
previous_hash,
));
}
}
// Persist the withdrawals, then apply the same view entries to the optional index.
store
.put_projection_batch(&repository_view_entries, &member_records, &owner_records)
.map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
if let Some(index) = current_repo_index {
index
.lock()
.map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))?
.apply_repository_view_entries(&repository_view_entries)
.map_err(RrdpSyncError::Storage)?;
}
Ok(published_count)
}
/// Persists one batch of published snapshot objects together with their projection records.
///
/// The batch is first passed to `prepare_repo_bytes_batch`. For every published URI a
/// repository-view entry, a source-member record and a URI-owner record are built from the
/// hash that the prepared batch maps the URI to. The prepared blobs are then written, followed
/// by the projection records, and finally the view entries are applied to
/// `current_repo_index` when one is supplied.
///
/// # Errors
///
/// Returns `RrdpSyncError::Storage` if preparation fails, if the prepared batch has no hash
/// for a published URI, if either store write fails, or if the index lock is poisoned or the
/// index update fails.
fn flush_snapshot_publish_batch(
    store: &RocksStore,
    notification_uri: &str,
    current_repo_index: Option<&CurrentRepoIndexHandle>,
    session_id: &str,
    serial: u64,
    published: &[(String, Vec<u8>)],
) -> Result<(), RrdpSyncError> {
    let prepared = prepare_repo_bytes_batch(published).map_err(RrdpSyncError::Storage)?;

    let batch_len = published.len();
    let mut view_entries = Vec::with_capacity(batch_len);
    let mut members = Vec::with_capacity(batch_len);
    let mut owners = Vec::with_capacity(batch_len);

    // Build all records before touching the store so a missing hash aborts without writes.
    for (uri, _) in published {
        let hash = match prepared.uri_to_hash.get(uri) {
            Some(found) => found.clone(),
            None => {
                return Err(RrdpSyncError::Storage(format!(
                    "missing raw_by_hash mapping for {uri}"
                )));
            }
        };
        view_entries.push(build_repository_view_present_entry(
            notification_uri,
            uri,
            &hash,
        ));
        members.push(build_rrdp_source_member_present_record(
            notification_uri,
            session_id,
            serial,
            uri,
            &hash,
        ));
        owners.push(build_rrdp_uri_owner_active_record(
            notification_uri,
            session_id,
            serial,
            uri,
            &hash,
        ));
    }

    // Blobs go in first, then the projection records that refer to them.
    store
        .put_blob_bytes_batch(&prepared.blobs_to_write)
        .map_err(|e| RrdpSyncError::Storage(e.to_string()))?;
    store
        .put_projection_batch(&view_entries, &members, &owners)
        .map_err(|e| RrdpSyncError::Storage(e.to_string()))?;

    match current_repo_index {
        Some(index) => index
            .lock()
            .map_err(|_| RrdpSyncError::Storage("current repo index lock poisoned".to_string()))?
            .apply_repository_view_entries(&view_entries)
            .map_err(RrdpSyncError::Storage),
        None => Ok(()),
    }
}
/// Error text used by `SnapshotSpoolWriter::write` when a buffer contains a byte above 0x7F.
/// `fetch_snapshot_into_tempfile` searches the fetcher's error string for this text in order
/// to report `RrdpError::NotAscii`.
const SNAPSHOT_NON_ASCII_ERROR: &str = "snapshot body contains non-ASCII bytes";
/// `Write` adapter that forwards ASCII-only data to `inner` while hashing and counting the
/// bytes that `inner` reports as written.
struct SnapshotSpoolWriter<'a, W: Write> {
/// Destination writer that receives the forwarded bytes.
inner: &'a mut W,
/// Running SHA-256 state, fed only the bytes `inner` reported as written.
hasher: sha2::Sha256,
/// Total number of bytes `inner` has reported as written.
bytes: u64,
}
impl<'a, W: Write> SnapshotSpoolWriter<'a, W> {
    /// Wraps `inner` with a fresh SHA-256 state and a byte counter starting at zero.
    fn new(inner: &'a mut W) -> Self {
        let hasher = sha2::Sha256::new();
        Self {
            inner,
            hasher,
            bytes: 0,
        }
    }

    /// Consumes the writer and returns the SHA-256 digest of all bytes hashed so far.
    fn finalize_hash(self) -> [u8; 32] {
        // The 32-byte digest output converts directly into a plain array.
        self.hasher.finalize().into()
    }
}
impl<W: Write> Write for SnapshotSpoolWriter<'_, W> {
    /// Forwards `buf` to the inner writer, hashing and counting the bytes it accepted.
    ///
    /// A buffer containing any non-ASCII byte is rejected as a whole with an `InvalidData`
    /// error carrying `SNAPSHOT_NON_ASCII_ERROR`; nothing is written or hashed in that case.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if !buf.is_ascii() {
            let rejected = std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                SNAPSHOT_NON_ASCII_ERROR,
            );
            return Err(rejected);
        }

        // Only the prefix the inner writer accepted is hashed, so a short write keeps the
        // digest in step with what was actually forwarded.
        let accepted = self.inner.write(buf)?;
        let forwarded = &buf[..accepted];
        self.hasher.update(forwarded);
        self.bytes += accepted as u64;
        Ok(accepted)
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
/// Downloads the snapshot at `snapshot_uri` into a new named temporary file.
///
/// The body is streamed through a `SnapshotSpoolWriter`, which rejects non-ASCII bytes and
/// hashes what is written. After the fetch, the SHA-256 digest is compared with
/// `expected_hash_sha256`, and the file is flushed and rewound to its start.
///
/// Returns the temporary file together with the byte count reported by the fetcher.
///
/// # Errors
///
/// Returns `RrdpSyncError::Fetch` if the temporary file cannot be created, flushed or
/// rewound, or if the fetcher fails for any reason other than non-ASCII content.
/// Returns the converted `RrdpError::NotAscii` when the fetcher's error text contains
/// `SNAPSHOT_NON_ASCII_ERROR`, and the converted `RrdpError::SnapshotHashMismatch` when the
/// digest differs from the expected one.
pub(super) fn fetch_snapshot_into_tempfile(
    fetcher: &dyn Fetcher,
    snapshot_uri: &str,
    expected_hash_sha256: &[u8; 32],
) -> Result<(tempfile::NamedTempFile, u64), RrdpSyncError> {
    let mut tmp = tempfile::NamedTempFile::new()
        .map_err(|e| RrdpSyncError::Fetch(format!("tempfile create failed: {e}")))?;

    let mut spool = SnapshotSpoolWriter::new(tmp.as_file_mut());
    let bytes_written = fetcher
        .fetch_to_writer(snapshot_uri, &mut spool)
        .map_err(|e| {
            // The spool writer's rejection surfaces only as text inside the fetcher's error.
            if e.contains(SNAPSHOT_NON_ASCII_ERROR) {
                RrdpSyncError::from(RrdpError::NotAscii)
            } else {
                RrdpSyncError::Fetch(e)
            }
        })?;

    if spool.finalize_hash() != *expected_hash_sha256 {
        return Err(RrdpError::SnapshotHashMismatch.into());
    }

    // Hand the file back positioned at the start so the caller can read it immediately.
    let file = tmp.as_file_mut();
    file.flush()
        .map_err(|e| RrdpSyncError::Fetch(format!("tempfile flush failed: {e}")))?;
    file.seek(SeekFrom::Start(0))
        .map_err(|e| RrdpSyncError::Fetch(format!("tempfile rewind failed: {e}")))?;

    Ok((tmp, bytes_written))
}