1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
20#[repr(C)]
21pub struct CountingTropical<T: TropicalScalar, C: TropicalScalar = T> {
22 pub value: T,
24 pub count: C,
26}
27
28impl<T: TropicalScalar, C: TropicalScalar> CountingTropical<T, C> {
29 #[inline(always)]
31 pub fn new(value: T, count: C) -> Self {
32 Self { value, count }
33 }
34
35 #[inline(always)]
37 pub fn from_value(value: T) -> Self {
38 Self {
39 value,
40 count: C::scalar_one(),
41 }
42 }
43}
44
45impl<T: TropicalScalar, C: TropicalScalar> TropicalSemiring for CountingTropical<T, C> {
46 type Scalar = T;
47
48 #[inline(always)]
49 fn tropical_zero() -> Self {
50 Self {
51 value: T::neg_infinity(),
52 count: C::scalar_zero(),
53 }
54 }
55
56 #[inline(always)]
57 fn tropical_one() -> Self {
58 Self {
59 value: T::scalar_zero(),
60 count: C::scalar_one(),
61 }
62 }
63
64 #[inline(always)]
65 fn tropical_add(self, rhs: Self) -> Self {
66 if self.value > rhs.value {
67 self
68 } else if self.value < rhs.value {
69 rhs
70 } else {
71 Self {
73 value: self.value,
74 count: self.count.scalar_add(rhs.count),
75 }
76 }
77 }
78
79 #[inline(always)]
80 fn tropical_mul(self, rhs: Self) -> Self {
81 Self {
82 value: self.value.scalar_add(rhs.value),
83 count: self.count.scalar_mul(rhs.count),
84 }
85 }
86
87 #[inline(always)]
88 fn value(&self) -> T {
89 self.value
90 }
91
92 #[inline(always)]
93 fn from_scalar(s: T) -> Self {
94 Self::from_value(s)
95 }
96}
97
98impl<T: TropicalScalar, C: TropicalScalar> TropicalWithArgmax for CountingTropical<T, C> {
99 type Index = u32;
100
101 #[inline(always)]
102 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
103 if self.value > rhs.value {
104 (self, self_idx)
105 } else if self.value < rhs.value {
106 (rhs, rhs_idx)
107 } else {
108 (
110 Self {
111 value: self.value,
112 count: self.count.scalar_add(rhs.count),
113 },
114 self_idx,
115 )
116 }
117 }
118}
119
120impl<T: TropicalScalar, C: TropicalScalar> SimdTropical for CountingTropical<T, C> {
121 const SIMD_AVAILABLE: bool = true;
123 const SIMD_WIDTH: usize = 8;
124}
125
126impl<T: TropicalScalar, C: TropicalScalar> Add for CountingTropical<T, C> {
127 type Output = Self;
128
129 #[inline(always)]
130 fn add(self, rhs: Self) -> Self::Output {
131 self.tropical_add(rhs)
132 }
133}
134
135impl<T: TropicalScalar, C: TropicalScalar> Mul for CountingTropical<T, C> {
136 type Output = Self;
137
138 #[inline(always)]
139 fn mul(self, rhs: Self) -> Self::Output {
140 self.tropical_mul(rhs)
141 }
142}
143
144impl<T: TropicalScalar, C: TropicalScalar> Default for CountingTropical<T, C> {
145 #[inline(always)]
146 fn default() -> Self {
147 Self::tropical_zero()
148 }
149}
150
151impl<T: TropicalScalar, C: TropicalScalar> fmt::Debug for CountingTropical<T, C> {
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 write!(f, "CountingTropical({}, {})", self.value, self.count)
154 }
155}
156
157impl<T: TropicalScalar, C: TropicalScalar> fmt::Display for CountingTropical<T, C> {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 write!(f, "({}, {})", self.value, self.count)
160 }
161}
162
163impl<T: TropicalScalar, C: TropicalScalar> From<T> for CountingTropical<T, C> {
164 #[inline(always)]
165 fn from(value: T) -> Self {
166 Self::from_value(value)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();
        let one = CountingTropical::<f64>::tropical_one();

        // Tropical zero is the identity for addition.
        let result = a.tropical_add(zero);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);

        // Tropical one is the identity for multiplication.
        let result = a.tropical_mul(one);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);
    }

    #[test]
    fn test_zero_annihilates() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();

        for result in [zero.tropical_mul(a), a.tropical_mul(zero)] {
            assert!(result.value.is_infinite() && result.value < 0.0);
            assert_eq!(result.count, 0.0);
        }
    }

    #[test]
    fn test_multiplication() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        // Values add, counts multiply.
        let result = a.tropical_mul(b);
        assert_eq!(result.value, 8.0);
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_addition_different_values() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        // The larger value wins and keeps its own count.
        let result = a.tropical_add(b);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_addition_equal_values() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        // Equal values: counts are summed.
        let result = a.tropical_add(b);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0);
    }

    #[test]
    fn test_addition_commutative_on_ties() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        assert_eq!(a.tropical_add(b), b.tropical_add(a));
    }

    #[test]
    fn test_addition_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 1.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 1.0);
    }

    #[test]
    fn test_distributivity() {
        let a = CountingTropical::<f64>::new(1.0, 2.0);
        let b = CountingTropical::<f64>::new(3.0, 1.0);
        let c = CountingTropical::<f64>::new(3.0, 4.0);

        // a * (b + c) == a * b + a * c, including the counts.
        let lhs = a.tropical_mul(b.tropical_add(c));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(c));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs.value, 4.0);
        assert_eq!(lhs.count, 10.0);
    }

    #[test]
    fn test_path_counting_example() {
        // Two groups of paths reach the same optimum (3.0): one path in the
        // first group and two in the second, so three optimal paths overall.
        let path1 = CountingTropical::<f64>::new(3.0, 1.0);
        let path2 = CountingTropical::<f64>::new(3.0, 2.0);

        let result = path1.tropical_add(path2);
        assert_eq!(result.value, 3.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        // `+` is tropical addition.
        let result = a + b;
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);

        // `*` is tropical multiplication.
        let result = a * b;
        assert_eq!(result.value, 8.0);
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_default() {
        let d = CountingTropical::<f64>::default();
        assert!(d.value.is_infinite() && d.value < 0.0);
        assert_eq!(d.count, 0.0);
    }

    #[test]
    fn test_display_debug() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);

        assert_eq!(format!("{}", a), "(3, 2)");
        assert_eq!(format!("{:?}", a), "CountingTropical(3, 2)");
    }

    #[test]
    fn test_from() {
        let a: CountingTropical<f64> = 5.0.into();
        assert_eq!(a.value, 5.0);
        assert_eq!(a.count, 1.0);

        let b = CountingTropical::<f64>::from(3.0);
        assert_eq!(b.value, 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_from_value() {
        let a = CountingTropical::<f64>::from_value(7.0);
        assert_eq!(a.value, 7.0);
        assert_eq!(a.count, 1.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        assert_eq!(a.value(), 5.0);

        let b = CountingTropical::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 2.0);
        let b = CountingTropical::<f64>::new(3.0, 1.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_rhs_wins() {
        let a = CountingTropical::<f64>::new(3.0, 1.0);
        let b = CountingTropical::<f64>::new(7.0, 2.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 2);
    }

    #[test]
    fn test_argmax_equal_counts_added() {
        // On a tie the counts are summed and the left-hand index is kept.
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_chain() {
        let mut acc = CountingTropical::<f64>::tropical_zero();
        let mut idx = 0u32;

        // 7.0 first appears at index 1 and ties again at index 2.
        let values = [(3.0, 1.0), (7.0, 2.0), (7.0, 3.0), (5.0, 1.0)];
        for (k, &(val, count)) in values.iter().enumerate() {
            let candidate = CountingTropical::new(val, count);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        // Counts of both 7.0 entries accumulate; the earlier index is kept.
        assert_eq!(acc.value, 7.0);
        assert_eq!(acc.count, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(CountingTropical::<f64>::SIMD_AVAILABLE);
        assert_eq!(CountingTropical::<f64>::SIMD_WIDTH, 8);
    }

    // `clone()` on a Copy type is deliberate here: the test exercises Clone.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn test_clone_copy() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a.value, a_copy.value);
        assert_eq!(a.count, a_copy.count);
        assert_eq!(a.value, a_clone.value);
        assert_eq!(a.count, a_clone.count);
    }

    #[test]
    fn test_eq() {
        let a1 = CountingTropical::<f64>::new(5.0, 2.0);
        let a2 = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = CountingTropical::<f32>::new(3.0, 2.0);
        let b = CountingTropical::<f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert!((result.value - 8.0).abs() < 1e-6);
        assert!((result.count - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_different_count_type() {
        // f64 values with f32 counts.
        let a = CountingTropical::<f64, f32>::new(3.0, 2.0);
        let b = CountingTropical::<f64, f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert_eq!(result.value, 8.0);
        assert!((result.count - 6.0).abs() < 1e-6);
    }
}