1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use once_cell::sync::Lazy;
4use sha2::{Sha256, Digest};
5use rand::prelude::SliceRandom;
6use crate::config::Config;
7use rand::RngCore;
8
9use rocksdb::{DB, Options, ColumnFamilyDescriptor};
10
11pub static PERSISTOR: Lazy<Box<dyn Persistor + Send + Sync>> = Lazy::new(|| {
12 let config = Config::new();
13 match config.database.as_str() {
14 "" => Box::new(MemoryPersistor::new()),
15 path => Box::new(DatabasePersistor::new(path)),
16 }
17});
18
/// Width in bytes of every identifier handled by a [`Persistor`]; it equals
/// the length of a SHA-256 digest, which is how branch and leaf ids are made.
pub const SIZE: usize = 32;

/// Fixed-size identifier used for root handles, branch ids and leaf ids.
pub type Word = [u8; SIZE];

/// Error returned by [`Persistor`] operations, carrying a human-readable
/// description of what went wrong (missing handle, stale compare, I/O error...).
#[derive(Debug)]
pub struct PersistorAccessError(pub String);

impl std::fmt::Display for PersistorAccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / similar.
impl std::error::Error for PersistorAccessError {}
28
29pub trait Persistor {
30 fn root_list(&self) -> Vec<Word>;
31 fn root_new(&self, handle: Word, root: Word) -> Result<Word, PersistorAccessError>;
32 fn root_temp(&self, root: Word) -> Result<Word, PersistorAccessError>;
33 fn root_get(&self, handle: Word) -> Result<Word, PersistorAccessError> ;
34 fn root_set(&self, handle: Word, old: Word, new: Word) -> Result<Word, PersistorAccessError>;
35 fn root_delete(&self, handle: Word) -> Result<(), PersistorAccessError>;
36 fn branch_set(&self, left: Word, right: Word) -> Result<Word, PersistorAccessError>;
37 fn branch_get(&self, branch: Word) -> Result<(Word, Word), PersistorAccessError>;
38 fn leaf_set(&self, content: Vec<u8>) -> Result<Word, PersistorAccessError>;
39 fn leaf_get(&self, leaf: Word) -> Result<Vec<u8>, PersistorAccessError>;
40}
41
42#[derive(Clone)]
43pub struct MemoryPersistor {
44 roots: Arc<Mutex<HashMap<Word, (Word, bool)>>>,
45 branches: Arc<Mutex<HashMap<Word, (Word, Word)>>>,
46 leaves: Arc<Mutex<HashMap<Word, Vec<u8>>>>,
47 references: Arc<Mutex<HashMap<Word, usize>>>,
48}
49
50impl MemoryPersistor {
51 pub fn new() -> Self {
52 Self {
53 roots: Arc::new(Mutex::new(HashMap::new())),
54 branches: Arc::new(Mutex::new(HashMap::new())),
55 leaves: Arc::new(Mutex::new(HashMap::new())),
56 references: Arc::new(Mutex::new(HashMap::new())),
57 }
58 }
59
60 fn reference_increment(&self, node: Word) {
61 let mut references = self.references.lock().unwrap();
62 match references.get(&node) {
63 Some(count) => {
64 let count_ = *count;
65 references.insert(node, count_ + 1);
66 },
67 None => { references.insert(node, 1); },
68 };
69 }
70
71 fn reference_decrement(&self, node: Word) {
72 let mut references = self.references.lock().unwrap();
73 match references.get(&node) {
74 Some(count_old) => {
75 let count_new = *count_old - 1;
76 if count_new > 0 {
77 references.insert(node, count_new);
78 } else {
79 references.remove(&node);
80 let mut branches = self.branches.lock().unwrap();
81 if let Some((left, right)) = branches.get(&node) {
82 let left_ = *left;
83 let right_ = *right;
84 branches.remove(&node);
85 drop(references);
86 drop(branches);
87 self.reference_decrement(left_);
88 self.reference_decrement(right_);
89 } else {
90 let mut leaves = self.leaves.lock().unwrap();
91 if let Some(_) = leaves.get(&node) {
92 leaves.remove(&node);
93 }
94 }
95 }
96 },
97 None => {},
98 };
99 }
100}
101
102impl Persistor for MemoryPersistor {
103 fn root_list(&self) -> Vec<Word> {
104 let mut keys: Vec<Word> = self.roots.lock().unwrap().iter()
105 .filter(|&(_, &(_, is_persistent))| is_persistent)
106 .map(|(key, _)| key)
107 .cloned()
108 .collect();
109 keys.sort();
110 keys
111 }
112
113 fn root_new(&self, handle: Word, root: Word) -> Result<Word, PersistorAccessError> {
114 let mut roots = self.roots.lock().unwrap();
115 match roots.get(&handle) {
116 Some(_) => Err(PersistorAccessError(format!("Handle {:?} already exists", handle))),
117 None => {
118 self.reference_increment(root);
119 roots.insert(handle, (root, true));
120 Ok(handle)
121 },
122 }
123 }
124
125 fn root_temp(&self, root: Word) -> Result<Word, PersistorAccessError> {
126 let mut roots = self.roots.lock().unwrap();
127 let mut handle: Word = [0 as u8; 32];
128 rand::thread_rng().fill_bytes(&mut handle);
129 match roots.get(&handle) {
130 Some(_) => Err(PersistorAccessError(format!("Handle {:?} already exists", handle))),
131 None => {
132 self.reference_increment(root);
133 roots.insert(handle, (root, false));
134 Ok(handle)
135 },
136 }
137 }
138
139 fn root_get(&self, handle: Word) -> Result<Word, PersistorAccessError> {
140 match self.roots.lock().unwrap().get(&handle) {
141 Some((root, _)) => Ok(*root),
142 None => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
143 }
144 }
145
146 fn root_set(&self, handle: Word, old: Word, new: Word) -> Result<Word, PersistorAccessError> {
147 let mut roots = self.roots.lock().unwrap();
148 match roots.get(&handle) {
149 Some((root, true)) if *root == old => {
150 self.reference_increment(new);
151 self.reference_decrement(old);
152 roots.insert(handle, (new, true));
153 Ok(handle)
154 },
155 Some((_, false)) => Err(PersistorAccessError(format!("Handle {:?} is temporary", handle))),
156 Some((_, true)) => Err(PersistorAccessError(format!("Handle {:?} changed since compare", handle))),
157 None => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
158 }
159 }
160
161 fn root_delete(&self, handle: Word) -> Result<(), PersistorAccessError> {
162 let mut roots = self.roots.lock().unwrap();
163 match roots.get(&handle) {
164 Some((old, _)) => {
165 self.reference_decrement(*old);
166 roots.remove(&handle);
167 Ok(())
168 },
169 None => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
170 }
171 }
172
173 fn branch_set(&self, left: Word, right: Word) -> Result<Word, PersistorAccessError> {
174 let mut joined = [0 as u8; SIZE * 2];
175 joined[..SIZE].copy_from_slice(&left);
176 joined[SIZE..].copy_from_slice(&right);
177
178 let branch = Sha256::digest(joined);
179 self.branches.lock().unwrap().insert(branch.into(), (left, right));
180 self.reference_increment(left);
181 self.reference_increment(right);
182 Ok(Word::from(branch))
183 }
184
185 fn branch_get(&self, branch: Word) -> Result<(Word, Word), PersistorAccessError> {
186 let branches = self.branches.lock().unwrap();
187 match branches.get(&branch) {
188 Some((left, right)) => {
189 let mut joined = [0 as u8; SIZE * 2];
190 joined[..SIZE].copy_from_slice(left);
191 joined[SIZE..].copy_from_slice(right);
192 assert!(Vec::from(branch) == Sha256::digest(joined).to_vec());
193 Ok((*left, *right))
194 },
195 None => Err(PersistorAccessError(format!("Branch {:?} not found", branch))),
196 }
197 }
198
199 fn leaf_set(&self, content: Vec<u8>) -> Result<Word, PersistorAccessError> {
200 let leaf = Word::from(Sha256::digest(Sha256::digest(&content)));
201 self.leaves.lock().unwrap().insert(leaf, content);
202 Ok(leaf)
203 }
204
205 fn leaf_get(&self, leaf: Word) -> Result<Vec<u8>, PersistorAccessError> {
206 let leaves = self.leaves.lock().unwrap();
207 match leaves.get(&leaf) {
208 Some(content) => {
209 assert!(Vec::from(leaf) == Sha256::digest(Sha256::digest(content)).to_vec());
210 Ok(content.to_vec())
211 }
212 None => Err(PersistorAccessError(format!("Leaf {:?} not found", leaf))),
213 }
214 }
215}
216
217pub struct DatabasePersistor {
218 db: Mutex<DB>,
219}
220
221impl DatabasePersistor {
222 pub fn new(path: &str) -> Self {
223
224 let mut opts = Options::default();
225 opts.create_if_missing(true);
226 opts.create_missing_column_families(true);
227
228 let cfs = vec![
229 ColumnFamilyDescriptor::new("roots", Options::default()),
230 ColumnFamilyDescriptor::new("branches", Options::default()),
231 ColumnFamilyDescriptor::new("leaves", Options::default()),
232 ColumnFamilyDescriptor::new("references", Options::default()),
233 ];
234
235 let persistor = Self {
236 db: Mutex::new(DB::open_cf_descriptors(&opts, path, cfs).unwrap()),
237 };
238
239 {
240 let mut handles: Vec<Word> = Vec::new();
241 let db = persistor.db.lock().unwrap();
242 let mut iter = db.raw_iterator_cf(db.cf_handle("roots").unwrap());
243 iter.seek_to_first();
244 while iter.valid() {
245 if (*iter.value().unwrap())[SIZE] == false as u8 {
246 handles.push((*iter.key().unwrap()).try_into().unwrap());
247 }
248 iter.next();
249 }
250 for handle in handles {
251 db.delete_cf(db.cf_handle("roots").unwrap(), handle).unwrap();
252 }
253 }
254
255 persistor
256 }
257
258 fn reference_increment(&self, node: Word) {
259 let db = self.db.lock().unwrap();
260 let references = db.cf_handle("references").unwrap();
261 match db.get_cf(references, node) {
262 Ok(Some(count)) => {
263 db.put_cf(
264 references,
265 node,
266 (usize::from_ne_bytes(count.try_into().unwrap()) + 1).to_ne_bytes(),
267 ).unwrap();
268 },
269 Ok(None) => { db.put_cf(references, node, (1 as usize).to_ne_bytes()).unwrap() },
270 Err(e) => { panic!{"{}", e} },
271 };
272 }
273
274 fn reference_decrement(&self, node: Word) {
275 let db = self.db.lock().unwrap();
276 let branches = db.cf_handle("branches").unwrap();
277 let leaves = db.cf_handle("leaves").unwrap();
278 let references = db.cf_handle("references").unwrap();
279 match db.get_cf(references, node).unwrap() {
280 Some(count_old) => {
281 let count_new = usize::from_ne_bytes(count_old.try_into().unwrap()) - 1;
282 if count_new > 0 {
283 db.put_cf(references, node, count_new.to_ne_bytes()).unwrap();
284 } else {
285 db.delete_cf(references, node).unwrap();
286 if let Some(value) = db.get_cf(branches, node).unwrap() {
287 let left = &value[..SIZE].try_into().unwrap();
288 let right = &value[SIZE..].try_into().unwrap();
289 db.delete_cf(branches, node).unwrap();
290 drop(db);
291 self.reference_decrement(*left);
292 self.reference_decrement(*right);
293 } else {
294 if let Some(_) = db.get_cf(leaves, node).unwrap() {
295 db.delete_cf(leaves, node).unwrap();
296 }
297 }
298 }
299 },
300 None => {},
301 };
302 }
303}
304
305impl Persistor for DatabasePersistor {
306 fn root_list(&self) -> Vec<Word> {
307 let mut handles: Vec<Word> = Vec::new();
308 let db = self.db.lock().unwrap();
309 let roots = db.cf_handle("roots").unwrap();
310 let mut iter = db.raw_iterator_cf(roots);
311 iter.seek_to_first();
312 while iter.valid() {
313 if (*iter.value().unwrap())[SIZE] != false as u8 {
314 handles.push((*iter.key().unwrap()).try_into().unwrap());
315 }
316 iter.next();
317 }
318
319 handles.shuffle(&mut rand::thread_rng());
320 handles
321 }
322
323 fn root_new(&self, handle: Word, root: Word) -> Result<Word, PersistorAccessError> {
324 let mut root_marked = [0 as u8; SIZE + 1];
325 root_marked[..SIZE].copy_from_slice(&root);
326 root_marked[SIZE] = true as u8;
327
328 let db = self.db.lock().unwrap();
329 let roots = db.cf_handle("roots").unwrap();
330 match db.get_cf(roots, handle) {
331 Ok(Some(_)) => Err(PersistorAccessError(format!("Handle {:?} already exists", handle))),
332 Ok(None) => {
333 db.put_cf(roots, handle, root_marked).unwrap();
334 drop(db);
335 self.reference_increment(root);
336 Ok(handle)
337 },
338 Err(e) => Err(PersistorAccessError(format!("{}", e))),
339 }
340 }
341
342 fn root_temp(&self, root: Word) -> Result<Word, PersistorAccessError> {
343 let mut root_marked = [0 as u8; SIZE + 1];
344 root_marked[..SIZE].copy_from_slice(&root);
345 root_marked[SIZE] = false as u8;
346
347 let mut handle: Word = [0 as u8; 32];
348 rand::thread_rng().fill_bytes(&mut handle);
349 let db = self.db.lock().unwrap();
350 let roots = db.cf_handle("roots").unwrap();
351 match db.get_cf(roots, handle) {
352 Ok(Some(_)) => Err(PersistorAccessError(format!("Handle {:?} already exists", handle))),
353 Ok(None) => {
354 db.put_cf(roots, handle, root_marked).unwrap();
355 drop(db);
356 self.reference_increment(root);
357 Ok(handle)
358 },
359 Err(e) => Err(PersistorAccessError(format!("{}", e))),
360 }
361 }
362
363 fn root_get(&self, handle: Word) -> Result<Word, PersistorAccessError> {
364 let db = self.db.lock().unwrap();
365 let roots = db.cf_handle("roots").unwrap();
366 match db.get_cf(roots, handle) {
367 Ok(Some(root_marked)) => Ok(((*root_marked)[..SIZE]).try_into().unwrap()),
368 Ok(None) => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
369 Err(e) => Err(PersistorAccessError(format!("{}", e))),
370 }
371 }
372
373 fn root_set(&self, handle: Word, old: Word, new: Word) -> Result<Word, PersistorAccessError> {
374 let db = self.db.lock().unwrap();
375 let roots = db.cf_handle("roots").unwrap();
376 match db.get_cf(roots, handle) {
377 Ok(Some(root_marked)) => match root_marked[SIZE] != false as u8 {
378 true => match root_marked[..SIZE] == old.to_vec() {
379 true => {
380 let mut new_marked = [0 as u8; SIZE + 1];
381 new_marked[..SIZE].copy_from_slice(&new);
382 new_marked[SIZE] = true as u8;
383 db.put_cf(roots, handle, new_marked).unwrap();
384 drop(db);
385 self.reference_increment(new);
386 self.reference_decrement(old);
387 Ok(handle)
388 },
389 false => Err(PersistorAccessError(format!("Handle {:?} changed since compare", handle))),
390 },
391 false => Err(PersistorAccessError(format!("Handle {:?} is temporary", handle))),
392 },
393 Ok(None) => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
394 Err(e) => Err(PersistorAccessError(format!("{}", e))),
395 }
396 }
397
398 fn root_delete(&self, handle: Word) -> Result<(), PersistorAccessError> {
399 let db = self.db.lock().unwrap();
400 let roots = db.cf_handle("roots").unwrap();
401 match db.get_cf(roots, handle) {
402 Ok(Some(root_marked)) => {
403 db.delete_cf(roots, handle).unwrap();
404 drop(db);
405 self.reference_decrement(root_marked[..SIZE].try_into().unwrap());
406 Ok(())
407 },
408 Ok(None) => Err(PersistorAccessError(format!("Handle {:?} not found", handle))),
409 Err(e) => Err(PersistorAccessError(format!("{}", e))),
410 }
411 }
412
413 fn branch_set(&self, left: Word, right: Word) -> Result<Word, PersistorAccessError> {
414 let mut joined = [0 as u8; SIZE * 2];
415 joined[..SIZE].copy_from_slice(&left);
416 joined[SIZE..].copy_from_slice(&right);
417
418 let branch = Sha256::digest(joined);
419
420 let db = self.db.lock().unwrap();
421 let branches = db.cf_handle("branches").unwrap();
422 db.put_cf(branches, branch, joined).unwrap();
423 drop(db);
424 self.reference_increment(left);
425 self.reference_increment(right);
426
427 Ok(Word::from(branch))
428 }
429
430 fn branch_get(&self, branch: Word) -> Result<(Word, Word), PersistorAccessError> {
431 let db = self.db.lock().unwrap();
432 let branches = db.cf_handle("branches").unwrap();
433 match db.get_cf(branches, branch) {
434 Ok(Some(value)) => {
435 assert!(Vec::from(branch) == Sha256::digest(value.clone()).to_vec());
436 let left = &value[..SIZE].try_into().unwrap();
437 let right = &value[SIZE..].try_into().unwrap();
438 Ok((*left, *right))
439 },
440 Ok(None) => Err(PersistorAccessError(format!("Branch {:?} not found", branch))),
441 Err(e) => Err(PersistorAccessError(format!("{}", e))),
442 }
443 }
444
445 fn leaf_set(&self, content: Vec<u8>) -> Result<Word, PersistorAccessError> {
446 let leaf = Word::from(Sha256::digest(Sha256::digest(&content)));
447 let db = self.db.lock().unwrap();
448 let leaves = db.cf_handle("leaves").unwrap();
449 db.put_cf(leaves, leaf, content.clone()).unwrap();
450 Ok(leaf)
451 }
452
453 fn leaf_get(&self, leaf: Word) -> Result<Vec<u8>, PersistorAccessError> {
454 let db = self.db.lock().unwrap();
455 let leaves = db.cf_handle("leaves").unwrap();
456 match db.get_cf(leaves, leaf) {
457 Ok(Some(content)) => {
458 assert!(leaf == *Sha256::digest(Sha256::digest(content.clone())));
459 Ok(content.to_vec())
460 },
461 Ok(None) => Err(PersistorAccessError(format!("Leaf {:?} not found", leaf))),
462 Err(e) => Err(PersistorAccessError(format!("{}", e))),
463 }
464 }
465}
466
#[cfg(test)]
mod tests {
    use std::fs;
    use std::sync::Mutex;

    use rocksdb::{IteratorMode, DB};

    use super::{DatabasePersistor, MemoryPersistor, Persistor, Word, SIZE};

    /// All-zero word, used both as a handle and as a node id.
    const ZEROS: Word = [0u8; SIZE];

    /// Scratch directory removed when created and again when dropped, so a
    /// failing assertion does not leave a stale database for the next run.
    struct ScratchDir(&'static str);

    impl ScratchDir {
        fn new(path: &'static str) -> Self {
            let _ = fs::remove_dir_all(path);
            ScratchDir(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(self.0);
        }
    }

    /// Counts the live keys in column family `cf`.
    fn cf_count(db: &Mutex<DB>, cf: &str) -> usize {
        let db = db.lock().unwrap();
        db.iterator_cf(db.cf_handle(cf).unwrap(), IteratorMode::Start).count()
    }

    /// Round-trips every `Persistor` operation once on a fresh store.
    fn test_persistence(persistor: Box<dyn Persistor>) {
        let temp = persistor.root_temp(ZEROS).unwrap();
        persistor.root_delete(temp).unwrap();

        let handle = persistor.root_new(ZEROS, ZEROS).unwrap();
        let handle = persistor.root_set(handle, ZEROS, ZEROS).unwrap();
        assert_eq!(persistor.root_get(handle).unwrap(), ZEROS);

        let branch = persistor.branch_set(ZEROS, ZEROS).unwrap();
        assert_eq!(persistor.branch_get(branch).unwrap(), (ZEROS, ZEROS));

        let leaf = persistor.leaf_set(vec![0]).unwrap();
        assert_eq!(persistor.leaf_get(leaf).unwrap(), vec![0]);
    }

    #[test]
    fn test_memory_persistence() {
        test_persistence(Box::new(MemoryPersistor::new()));
    }

    #[test]
    fn test_database_persistence() {
        let dir = ScratchDir::new(".test-database-persistence");
        test_persistence(Box::new(DatabasePersistor::new(dir.0)));
    }

    /// Builds the tree root -> b(a(0, 1), 2), swaps it for root -> c(2, 3)
    /// and checks that everything only the old tree used was collected.
    #[test]
    fn test_memory_garbage() {
        let persistor = MemoryPersistor::new();
        // (roots, branches, leaves, references)
        let sizes = |p: &MemoryPersistor| {
            (
                p.roots.lock().unwrap().len(),
                p.branches.lock().unwrap().len(),
                p.leaves.lock().unwrap().len(),
                p.references.lock().unwrap().len(),
            )
        };

        let handle: Word = ZEROS;
        let leaf_0 = persistor.leaf_set(vec![0]).unwrap();
        let leaf_1 = persistor.leaf_set(vec![1]).unwrap();
        let leaf_2 = persistor.leaf_set(vec![2]).unwrap();

        let branch_a = persistor.branch_set(leaf_0, leaf_1).unwrap();
        let branch_b = persistor.branch_set(branch_a, leaf_2).unwrap();

        persistor.root_new(handle, branch_b).unwrap();
        assert_eq!(sizes(&persistor), (1, 2, 3, 5));

        let leaf_3 = persistor.leaf_set(vec![3]).unwrap();
        let branch_c = persistor.branch_set(leaf_2, leaf_3).unwrap();
        persistor.root_set(handle, branch_b, branch_c).unwrap();
        assert_eq!(sizes(&persistor), (1, 1, 2, 3));
    }

    /// Same scenario as `test_memory_garbage`, against RocksDB.
    #[test]
    fn test_database_garbage() {
        let dir = ScratchDir::new(".test-database-garbage");
        let persistor = DatabasePersistor::new(dir.0);
        // (roots, branches, leaves, references)
        let sizes = |p: &DatabasePersistor| {
            (
                cf_count(&p.db, "roots"),
                cf_count(&p.db, "branches"),
                cf_count(&p.db, "leaves"),
                cf_count(&p.db, "references"),
            )
        };

        let handle: Word = ZEROS;
        let leaf_0 = persistor.leaf_set(vec![0]).unwrap();
        let leaf_1 = persistor.leaf_set(vec![1]).unwrap();
        let leaf_2 = persistor.leaf_set(vec![2]).unwrap();

        let branch_a = persistor.branch_set(leaf_0, leaf_1).unwrap();
        let branch_b = persistor.branch_set(branch_a, leaf_2).unwrap();

        persistor.root_new(handle, branch_b).unwrap();
        assert_eq!(sizes(&persistor), (1, 2, 3, 5));

        let leaf_3 = persistor.leaf_set(vec![3]).unwrap();
        let branch_c = persistor.branch_set(leaf_2, leaf_3).unwrap();
        persistor.root_set(handle, branch_b, branch_c).unwrap();
        assert_eq!(sizes(&persistor), (1, 1, 2, 3));
    }
}