1use crate::lattice::Lattice;
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeMap;
13
14#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
19pub struct PNCounter<K: Ord + Clone> {
20 increments: BTreeMap<K, u64>,
22 decrements: BTreeMap<K, u64>,
24}
25
26impl<K: Ord + Clone> PNCounter<K> {
27 pub fn new() -> Self {
29 Self {
30 increments: BTreeMap::new(),
31 decrements: BTreeMap::new(),
32 }
33 }
34
35 pub fn increment(&mut self, replica_id: K, amount: u64) {
37 let entry = self.increments.entry(replica_id).or_insert(0);
38 *entry = entry.saturating_add(amount);
39 }
40
41 pub fn decrement(&mut self, replica_id: K, amount: u64) {
43 let entry = self.decrements.entry(replica_id).or_insert(0);
44 *entry = entry.saturating_add(amount);
45 }
46
47 pub fn value(&self) -> i64 {
49 let inc_sum: u64 = self.increments.values().sum();
50 let dec_sum: u64 = self.decrements.values().sum();
51 (inc_sum as i64).saturating_sub(dec_sum as i64)
52 }
53
54 pub fn get_increment(&self, replica_id: &K) -> u64 {
56 self.increments.get(replica_id).copied().unwrap_or(0)
57 }
58
59 pub fn get_decrement(&self, replica_id: &K) -> u64 {
61 self.decrements.get(replica_id).copied().unwrap_or(0)
62 }
63
64 pub fn increments(&self) -> &BTreeMap<K, u64> {
66 &self.increments
67 }
68
69 pub fn decrements(&self) -> &BTreeMap<K, u64> {
71 &self.decrements
72 }
73}
74
75impl<K: Ord + Clone> Default for PNCounter<K> {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl<K: Ord + Clone> Lattice for PNCounter<K> {
82 fn bottom() -> Self {
83 Self::new()
84 }
85
86 fn join(&self, other: &Self) -> Self {
89 let mut increments = self.increments.clone();
90 let mut decrements = self.decrements.clone();
91
92 for (k, v) in &other.increments {
94 increments
95 .entry(k.clone())
96 .and_modify(|e| *e = (*e).max(*v))
97 .or_insert(*v);
98 }
99
100 for (k, v) in &other.decrements {
101 decrements
102 .entry(k.clone())
103 .and_modify(|e| *e = (*e).max(*v))
104 .or_insert(*v);
105 }
106
107 Self {
108 increments,
109 decrements,
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pncounter_basic_operations() {
        let mut counter = PNCounter::new();

        counter.increment("A", 5);
        assert_eq!(counter.value(), 5);

        counter.decrement("B", 2);
        assert_eq!(counter.value(), 3);

        counter.increment("A", 3);
        assert_eq!(counter.value(), 6);
    }

    #[test]
    fn test_pncounter_increment_saturates_per_replica() {
        let mut counter = PNCounter::new();
        counter.increment("A", u64::MAX);
        counter.increment("A", 1);
        assert_eq!(counter.get_increment(&"A"), u64::MAX);
    }

    #[test]
    fn test_pncounter_join_idempotent() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);
        c1.decrement("B", 2);

        let joined = c1.join(&c1);
        // Compare whole states: equal values alone would not prove idempotence.
        assert_eq!(joined, c1);
        assert_eq!(joined.value(), 3);
    }

    #[test]
    fn test_pncounter_join_commutative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 5);

        let mut c2 = PNCounter::new();
        c2.increment("B", 3);
        c2.decrement("A", 1);

        let joined1 = c1.join(&c2);
        let joined2 = c2.join(&c1);

        assert_eq!(joined1, joined2);
        assert_eq!(joined1.value(), 7);
        assert_eq!(joined1.get_increment(&"A"), 5);
        assert_eq!(joined1.get_increment(&"B"), 3);
        assert_eq!(joined1.get_decrement(&"A"), 1);
    }

    #[test]
    fn test_pncounter_join_associative() {
        let mut c1 = PNCounter::new();
        c1.increment("A", 1);

        let mut c2 = PNCounter::new();
        c2.increment("B", 2);

        let mut c3 = PNCounter::new();
        c3.decrement("C", 1);

        let left = c1.join(&c2).join(&c3);
        let right = c1.join(&c2.join(&c3));

        assert_eq!(left, right);
        assert_eq!(left.value(), 2);
    }

    #[test]
    fn test_pncounter_join_takes_per_replica_max() {
        let mut older = PNCounter::new();
        older.increment("A", 3);

        let mut newer = older.clone();
        newer.increment("A", 2);
        newer.decrement("A", 1);

        // Merging a stale state must not double-count or regress the newer one.
        let merged = newer.join(&older);
        assert_eq!(merged, newer);
        assert_eq!(merged.get_increment(&"A"), 5);
        assert_eq!(merged.get_decrement(&"A"), 1);
        assert_eq!(merged.value(), 4);
    }

    #[test]
    fn test_pncounter_bottom_is_identity() {
        let mut counter = PNCounter::new();
        counter.increment("A", 5);
        counter.decrement("B", 2);

        let bottom = PNCounter::bottom();

        assert_eq!(counter.join(&bottom), counter);
        assert_eq!(bottom.join(&counter), counter);
    }

    #[test]
    fn test_pncounter_convergence_different_order() {
        let mut c1 = PNCounter::new();
        c1.increment("X", 10);
        c1.decrement("Y", 3);

        let mut c2 = PNCounter::new();
        c2.increment("Z", 5);
        c2.decrement("X", 2);

        let mut state1 = PNCounter::bottom();
        state1.join_assign(&c1);
        state1.join_assign(&c2);

        let mut state2 = PNCounter::bottom();
        state2.join_assign(&c2);
        state2.join_assign(&c1);

        assert_eq!(state1, state2);
        assert_eq!(state1.value(), 10);
    }

    #[test]
    fn test_pncounter_serialization() {
        let mut counter = PNCounter::new();
        counter.increment("replica1", 100);
        counter.decrement("replica2", 25);

        let serialized = serde_json::to_string(&counter).unwrap();
        let deserialized: PNCounter<String> = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.value(), counter.value());
        assert_eq!(deserialized.get_increment(&"replica1".to_string()), 100);
        assert_eq!(deserialized.get_decrement(&"replica2".to_string()), 25);
    }
}