// mdcs_core/pncounter.rs

//! PN-Counter (Positive-Negative Counter) CRDT
//!
//! A PN-Counter supports both increment and decrement operations by maintaining
//! two separate counters: one for increments (P) and one for decrements (N).
//! The value is sum(P) - sum(N).
//!
//! Each replica has its own entry in both counters, and the join operation takes
//! the per-replica maximum of each counter across all replicas.
9
10use crate::lattice::Lattice;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeMap;
13
14/// A Positive-Negative Counter CRDT
15///
16/// Supports both increment and decrement by maintaining two separate counters.
17/// Value = sum(increments) - sum(decrements)
18#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub struct PNCounter<K: Ord + Clone> {
20    /// Per-replica increment counters
21    increments: BTreeMap<K, u64>,
22    /// Per-replica decrement counters
23    decrements: BTreeMap<K, u64>,
24}
25
26impl<K: Ord + Clone> PNCounter<K> {
27    /// Create a new PN-Counter
28    pub fn new() -> Self {
29        Self {
30            increments: BTreeMap::new(),
31            decrements: BTreeMap::new(),
32        }
33    }
34
35    /// Increment the counter for a specific replica
36    pub fn increment(&mut self, replica_id: K, amount: u64) {
37        let entry = self.increments.entry(replica_id).or_insert(0);
38        *entry = entry.saturating_add(amount);
39    }
40
41    /// Decrement the counter for a specific replica
42    pub fn decrement(&mut self, replica_id: K, amount: u64) {
43        let entry = self.decrements.entry(replica_id).or_insert(0);
44        *entry = entry.saturating_add(amount);
45    }
46
47    /// Get the current value (sum of increments - sum of decrements)
48    pub fn value(&self) -> i64 {
49        let inc_sum: u64 = self.increments.values().sum();
50        let dec_sum: u64 = self.decrements.values().sum();
51        (inc_sum as i64).saturating_sub(dec_sum as i64)
52    }
53
54    /// Get the increment counter for a replica
55    pub fn get_increment(&self, replica_id: &K) -> u64 {
56        self.increments.get(replica_id).copied().unwrap_or(0)
57    }
58
59    /// Get the decrement counter for a replica
60    pub fn get_decrement(&self, replica_id: &K) -> u64 {
61        self.decrements.get(replica_id).copied().unwrap_or(0)
62    }
63
64    /// Get a reference to all increment counters
65    pub fn increments(&self) -> &BTreeMap<K, u64> {
66        &self.increments
67    }
68
69    /// Get a reference to all decrement counters
70    pub fn decrements(&self) -> &BTreeMap<K, u64> {
71        &self.decrements
72    }
73}
74
75impl<K: Ord + Clone> Default for PNCounter<K> {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl<K: Ord + Clone> Lattice for PNCounter<K> {
82    fn bottom() -> Self {
83        Self::new()
84    }
85
86    /// Join operation performs component-wise max on both counters
87    /// This ensures that concurrent updates always converge to the same value
88    fn join(&self, other: &Self) -> Self {
89        let mut increments = self.increments.clone();
90        let mut decrements = self.decrements.clone();
91
92        // Merge counters by taking the per-replica maximum from `other` into this replica
93        for (k, v) in &other.increments {
94            increments
95                .entry(k.clone())
96                .and_modify(|e| *e = (*e).max(*v))
97                .or_insert(*v);
98        }
99
100        for (k, v) in &other.decrements {
101            decrements
102                .entry(k.clone())
103                .and_modify(|e| *e = (*e).max(*v))
104                .or_insert(*v);
105        }
106
107        Self {
108            increments,
109            decrements,
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pncounter_basic_operations() {
        let mut counter = PNCounter::new();

        // Increment from replica "A"
        counter.increment("A", 5);
        assert_eq!(counter.value(), 5);

        // Decrement from replica "B"
        counter.decrement("B", 2);
        assert_eq!(counter.value(), 3);

        // Increment again; both increments accumulate in A's slot
        counter.increment("A", 3);
        assert_eq!(counter.value(), 6);
        assert_eq!(counter.get_increment(&"A"), 8);
        assert_eq!(counter.get_decrement(&"B"), 2);

        // Slots that were never touched read as zero
        assert_eq!(counter.get_increment(&"B"), 0);
        assert_eq!(counter.get_decrement(&"A"), 0);
    }

    // The lattice-law tests below compare whole states, not just `value()`.
    // Different states can share a value (e.g. +3 versus +5/-2), so comparing
    // values alone would not catch a join that loses or mixes up replica entries.

    #[test]
    fn test_pncounter_join_idempotent() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("B", 2);

        let joined = c1.join(&c1);
        assert_eq!(joined, c1);
        assert_eq!(joined.value(), 3);
    }

    #[test]
    fn test_pncounter_join_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);

        let mut c2 = PNCounter::new();
        c2.increment("B", 3);
        c2.decrement("A", 1);

        let joined1 = c1.join(&c2);
        let joined2 = c2.join(&c1);

        assert_eq!(joined1, joined2);
        assert_eq!(joined1.value(), 7);
        assert_eq!(joined1.get_increment(&"A"), 5);
        assert_eq!(joined1.get_increment(&"B"), 3);
        assert_eq!(joined1.get_decrement(&"A"), 1);
    }

    #[test]
    fn test_pncounter_join_associative() {
        // Replicas overlap across the three states so the per-replica max is exercised
        let mut c1 = PNCounter::new();
        c1.increment("A", 1);
        c1.decrement("C", 4);

        let mut c2 = PNCounter::new();
        c2.increment("B", 2);
        c2.increment("A", 3);

        let mut c3 = PNCounter::new();
        c3.decrement("C", 1);
        c3.increment("B", 5);

        let left = c1.join(&c2).join(&c3);
        let right = c1.join(&c2.join(&c3));

        assert_eq!(left, right);
        // increments: A = max(1, 3), B = max(2, 5); decrements: C = max(4, 1)
        assert_eq!(left.get_increment(&"A"), 3);
        assert_eq!(left.get_increment(&"B"), 5);
        assert_eq!(left.get_decrement(&"C"), 4);
        assert_eq!(left.value(), 4);
    }

    #[test]
    fn test_pncounter_bottom_is_identity() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        counter.decrement("B", 2);

        let bottom: PNCounter<&str> = PNCounter::bottom();
        assert_eq!(bottom.value(), 0);

        // bottom is an identity for join from either side
        assert_eq!(counter.join(&bottom), counter);
        assert_eq!(bottom.join(&counter), counter);
    }

    #[test]
    fn test_pncounter_convergence_different_order() {
        let mut c1 = PNCounter::new();
        c1.increment("X", 10);
        c1.decrement("Y", 3);

        let mut c2 = PNCounter::new();
        c2.increment("Z", 5);
        c2.decrement("X", 2);

        // Apply updates in different order
        let mut state1 = PNCounter::bottom();
        state1.join_assign(&c1);
        state1.join_assign(&c2);

        let mut state2 = PNCounter::bottom();
        state2.join_assign(&c2);
        state2.join_assign(&c1);

        assert_eq!(state1, state2);
        assert_eq!(state1.value(), 10);
    }

    #[test]
    fn test_pncounter_serialization() {
        let mut counter = PNCounter::new();
        counter.increment("replica1", 100);
        counter.decrement("replica2", 25);

        let serialized = serde_json::to_string(&counter).unwrap();
        let deserialized: PNCounter<String> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.value(), counter.value());
        assert_eq!(deserialized.get_increment(&"replica1".to_string()), 100);
        assert_eq!(deserialized.get_decrement(&"replica2".to_string()), 25);
    }
}