1use crate::version_vector::VersionVector;
7use mdcs_merkle::{Hash, Hasher, MerkleNode, NodeBuilder, Payload};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use thiserror::Error;
11
12#[derive(Error, Debug)]
14pub enum SnapshotError {
15 #[error("Snapshot not found: {0}")]
16 NotFound(String),
17
18 #[error("Invalid snapshot data: {0}")]
19 InvalidData(String),
20
21 #[error("Serialization error: {0}")]
22 SerializationError(String),
23
24 #[error("Version mismatch: expected {expected}, got {actual}")]
25 VersionMismatch { expected: u8, actual: u8 },
26
27 #[error("Snapshot too old: {0}")]
28 TooOld(String),
29}
30
/// Format version written into every new snapshot and required when decoding one.
pub const SNAPSHOT_VERSION: u8 = 1;
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
36pub struct Snapshot {
37 pub version: u8,
39
40 pub id: Hash,
42
43 pub version_vector: VersionVector,
46
47 pub superseded_roots: Vec<Hash>,
50
51 pub state_data: Vec<u8>,
53
54 pub created_at: u64,
56
57 pub creator: String,
59
60 pub metadata: HashMap<String, String>,
62}
63
64impl Snapshot {
65 pub fn new(
67 version_vector: VersionVector,
68 superseded_roots: Vec<Hash>,
69 state_data: Vec<u8>,
70 creator: impl Into<String>,
71 created_at: u64,
72 ) -> Self {
73 let creator = creator.into();
74
75 let mut hasher = Hasher::new();
77 hasher.update(&[SNAPSHOT_VERSION]);
78 hasher.update(&state_data);
79 for entry in version_vector.to_entries() {
80 hasher.update(entry.replica_id.as_bytes());
81 hasher.update(&entry.sequence.to_le_bytes());
82 }
83 hasher.update(&created_at.to_le_bytes());
84 hasher.update(creator.as_bytes());
85 let id = hasher.finalize();
86
87 Snapshot {
88 version: SNAPSHOT_VERSION,
89 id,
90 version_vector,
91 superseded_roots,
92 state_data,
93 created_at,
94 creator,
95 metadata: HashMap::new(),
96 }
97 }
98
99 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
101 self.metadata.insert(key.into(), value.into());
102 self
103 }
104
105 pub fn to_merkle_node(&self) -> Result<MerkleNode, SnapshotError> {
107 let payload_data = serde_json::to_vec(self)
108 .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
109
110 Ok(NodeBuilder::new()
111 .with_parents(self.superseded_roots.clone())
112 .with_payload(Payload::snapshot(payload_data))
113 .with_timestamp(self.created_at)
114 .with_creator(&self.creator)
115 .build())
116 }
117
118 pub fn from_merkle_node(node: &MerkleNode) -> Result<Self, SnapshotError> {
120 match &node.payload {
121 Payload::Snapshot(data) => {
122 let snapshot: Snapshot = serde_json::from_slice(data)
123 .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
124
125 if snapshot.version != SNAPSHOT_VERSION {
126 return Err(SnapshotError::VersionMismatch {
127 expected: SNAPSHOT_VERSION,
128 actual: snapshot.version,
129 });
130 }
131
132 Ok(snapshot)
133 }
134 _ => Err(SnapshotError::InvalidData(
135 "Node does not contain snapshot payload".to_string(),
136 )),
137 }
138 }
139
140 pub fn covers(&self, vv: &VersionVector) -> bool {
142 self.version_vector.dominates(vv)
143 }
144
145 pub fn size(&self) -> usize {
147 self.state_data.len()
148 }
149}
150
151pub struct SnapshotManager {
153 snapshots: HashMap<Hash, Snapshot>,
155
156 by_creator: HashMap<String, Vec<Hash>>,
158
159 latest: Option<Hash>,
161
162 config: SnapshotConfig,
164}
165
/// Tunables that control when snapshots are taken and how many are kept.
#[derive(Clone, Debug)]
pub struct SnapshotConfig {
    /// Number of operations since the latest snapshot at which a new snapshot
    /// is recommended.
    pub min_operations_between: u64,

    /// Elapsed time since the latest snapshot at which a new snapshot is
    /// recommended. Uses the same unit as `Snapshot::created_at`.
    /// NOTE(review): that unit is not established in this file.
    pub max_time_between: u64,

    /// Upper bound on how many snapshots the manager retains.
    pub max_snapshots: usize,

    /// When `false`, `SnapshotManager::should_snapshot` always answers `false`.
    pub auto_snapshot: bool,
}
181
182impl Default for SnapshotConfig {
183 fn default() -> Self {
184 SnapshotConfig {
185 min_operations_between: 1000,
186 max_time_between: 10000,
187 max_snapshots: 10,
188 auto_snapshot: true,
189 }
190 }
191}
192
193impl SnapshotManager {
194 pub fn new() -> Self {
196 SnapshotManager {
197 snapshots: HashMap::new(),
198 by_creator: HashMap::new(),
199 latest: None,
200 config: SnapshotConfig::default(),
201 }
202 }
203
204 pub fn with_config(config: SnapshotConfig) -> Self {
206 SnapshotManager {
207 snapshots: HashMap::new(),
208 by_creator: HashMap::new(),
209 latest: None,
210 config,
211 }
212 }
213
214 pub fn config(&self) -> &SnapshotConfig {
216 &self.config
217 }
218
219 pub fn store(&mut self, snapshot: Snapshot) -> Hash {
221 let id = snapshot.id;
222
223 self.by_creator
224 .entry(snapshot.creator.clone())
225 .or_default()
226 .push(id);
227
228 if let Some(latest_id) = self.latest {
230 if let Some(latest) = self.snapshots.get(&latest_id) {
231 if snapshot.version_vector.dominates(&latest.version_vector) {
232 self.latest = Some(id);
233 }
234 }
235 } else {
236 self.latest = Some(id);
237 }
238
239 self.snapshots.insert(id, snapshot);
240
241 self.gc_old_snapshots();
243
244 id
245 }
246
247 pub fn get(&self, id: &Hash) -> Option<&Snapshot> {
249 self.snapshots.get(id)
250 }
251
252 pub fn latest(&self) -> Option<&Snapshot> {
254 self.latest.and_then(|id| self.snapshots.get(&id))
255 }
256
257 pub fn latest_id(&self) -> Option<Hash> {
259 self.latest
260 }
261
262 pub fn by_creator(&self, creator: &str) -> Vec<&Snapshot> {
264 self.by_creator
265 .get(creator)
266 .map(|ids| ids.iter().filter_map(|id| self.snapshots.get(id)).collect())
267 .unwrap_or_default()
268 }
269
270 pub fn find_covering(&self, vv: &VersionVector) -> Option<&Snapshot> {
272 self.snapshots
273 .values()
274 .filter(|s| s.covers(vv))
275 .max_by_key(|s| s.version_vector.total_operations())
276 }
277
278 pub fn should_snapshot(&self, current_vv: &VersionVector, current_time: u64) -> bool {
280 if !self.config.auto_snapshot {
281 return false;
282 }
283
284 match self.latest() {
285 None => true, Some(latest) => {
287 let ops_since =
288 current_vv.total_operations() - latest.version_vector.total_operations();
289 let time_since = current_time.saturating_sub(latest.created_at);
290
291 ops_since >= self.config.min_operations_between
292 || time_since >= self.config.max_time_between
293 }
294 }
295 }
296
297 fn gc_old_snapshots(&mut self) {
299 while self.snapshots.len() > self.config.max_snapshots {
300 let oldest = self
302 .snapshots
303 .iter()
304 .filter(|(id, _)| Some(**id) != self.latest)
305 .min_by_key(|(_, s)| s.created_at)
306 .map(|(id, _)| *id);
307
308 if let Some(id) = oldest {
309 if let Some(snapshot) = self.snapshots.remove(&id) {
310 if let Some(creator_snapshots) = self.by_creator.get_mut(&snapshot.creator) {
311 creator_snapshots.retain(|&sid| sid != id);
312 }
313 }
314 } else {
315 break;
316 }
317 }
318 }
319
320 pub fn stats(&self) -> SnapshotStats {
322 let total_size: usize = self.snapshots.values().map(|s| s.size()).sum();
323 let oldest = self.snapshots.values().map(|s| s.created_at).min();
324 let newest = self.snapshots.values().map(|s| s.created_at).max();
325
326 SnapshotStats {
327 count: self.snapshots.len(),
328 total_size,
329 oldest_timestamp: oldest,
330 newest_timestamp: newest,
331 }
332 }
333}
334
335impl Default for SnapshotManager {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
/// Aggregate figures describing the snapshots a manager currently retains.
#[derive(Clone, Debug)]
pub struct SnapshotStats {
    /// Number of retained snapshots.
    pub count: usize,
    /// Sum of the serialized state sizes, in bytes.
    pub total_size: usize,
    /// Smallest `created_at` among retained snapshots; `None` when there are none.
    pub oldest_timestamp: Option<u64>,
    /// Largest `created_at` among retained snapshots; `None` when there are none.
    pub newest_timestamp: Option<u64>,
}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built snapshot exposes exactly the values it was given.
    #[test]
    fn test_snapshot_creation() {
        let clock = VersionVector::from_entries([("r1".to_string(), 10), ("r2".to_string(), 5)]);
        let payload = b"test state data".to_vec();
        let parents = vec![Hasher::hash(b"root1")];

        let snapshot = Snapshot::new(clock.clone(), parents.clone(), payload.clone(), "r1", 100);

        assert_eq!(snapshot.version, SNAPSHOT_VERSION);
        assert_eq!(snapshot.version_vector, clock);
        assert_eq!(snapshot.state_data, payload);
        assert_eq!(snapshot.created_at, 100);
        assert_eq!(snapshot.creator, "r1");
    }

    /// `covers` accepts a vector that is behind on every replica and rejects
    /// one that is ahead on any replica.
    #[test]
    fn test_snapshot_covers() {
        let at_snapshot =
            VersionVector::from_entries([("r1".to_string(), 10), ("r2".to_string(), 5)]);
        let behind = VersionVector::from_entries([("r1".to_string(), 5), ("r2".to_string(), 3)]);
        let ahead = VersionVector::from_entries([("r1".to_string(), 15), ("r2".to_string(), 5)]);

        let snapshot = Snapshot::new(at_snapshot, vec![], vec![], "r1", 100);

        assert!(snapshot.covers(&behind));
        assert!(!snapshot.covers(&ahead));
    }

    /// Encoding to a Merkle node and decoding again preserves id and version vector.
    #[test]
    fn test_snapshot_to_merkle_node() {
        let clock = VersionVector::from_entries([("r1".to_string(), 10)]);
        let original = Snapshot::new(clock, vec![], b"data".to_vec(), "r1", 100);

        let node = original.to_merkle_node().unwrap();
        assert!(matches!(node.payload, Payload::Snapshot(_)));

        let decoded = Snapshot::from_merkle_node(&node).unwrap();
        assert_eq!(decoded.id, original.id);
        assert_eq!(decoded.version_vector, original.version_vector);
    }

    /// The first stored snapshot is retrievable by id and becomes the latest.
    #[test]
    fn test_snapshot_manager_store_and_get() {
        let mut manager = SnapshotManager::new();

        let clock = VersionVector::from_entries([("r1".to_string(), 10)]);
        let snapshot = Snapshot::new(clock, vec![], b"data".to_vec(), "r1", 100);
        let expected_id = snapshot.id;

        manager.store(snapshot);

        assert!(manager.get(&expected_id).is_some());
        assert!(manager.latest().is_some());
        assert_eq!(manager.latest_id(), Some(expected_id));
    }

    /// Storing more snapshots than `max_snapshots` trims the store to the limit
    /// while keeping a latest snapshot available.
    #[test]
    fn test_snapshot_manager_gc() {
        let mut manager = SnapshotManager::with_config(SnapshotConfig {
            max_snapshots: 3,
            ..Default::default()
        });

        for tick in 0..5u64 {
            let clock = VersionVector::from_entries([("r1".to_string(), tick + 1)]);
            manager.store(Snapshot::new(clock, vec![], b"data".to_vec(), "r1", tick));
        }

        assert_eq!(manager.snapshots.len(), 3);

        assert!(manager.latest().is_some());
    }

    /// `should_snapshot` fires with no snapshot, stays quiet below both
    /// thresholds, and fires once the operation threshold is reached.
    #[test]
    fn test_should_snapshot() {
        let mut manager = SnapshotManager::with_config(SnapshotConfig {
            min_operations_between: 100,
            max_time_between: 1000,
            auto_snapshot: true,
            ..Default::default()
        });

        // No snapshot stored yet.
        let initial = VersionVector::from_entries([("r1".to_string(), 10)]);
        assert!(manager.should_snapshot(&initial, 100));

        let snapshot = Snapshot::new(initial.clone(), vec![], b"data".to_vec(), "r1", 100);
        manager.store(snapshot);

        // 40 operations and 100 time units later: below both thresholds.
        let few_ops = VersionVector::from_entries([("r1".to_string(), 50)]);
        assert!(!manager.should_snapshot(&few_ops, 200));

        // 140 operations later: operation threshold reached.
        let many_ops = VersionVector::from_entries([("r1".to_string(), 150)]);
        assert!(manager.should_snapshot(&many_ops, 200));
    }
}