1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
20#[repr(C)]
21pub struct CountingTropical<T: TropicalScalar, C: TropicalScalar = T> {
22 pub value: T,
24 pub count: C,
26}
27
28impl<T: TropicalScalar, C: TropicalScalar> CountingTropical<T, C> {
29 #[inline(always)]
31 pub fn new(value: T, count: C) -> Self {
32 Self { value, count }
33 }
34
35 #[inline(always)]
37 pub fn from_value(value: T) -> Self {
38 Self {
39 value,
40 count: C::scalar_one(),
41 }
42 }
43}
44
45impl<T: TropicalScalar, C: TropicalScalar> TropicalSemiring for CountingTropical<T, C> {
46 type Scalar = T;
47
48 #[inline(always)]
49 fn tropical_zero() -> Self {
50 Self {
51 value: T::neg_infinity(),
52 count: C::scalar_zero(),
53 }
54 }
55
56 #[inline(always)]
57 fn tropical_one() -> Self {
58 Self {
59 value: T::scalar_zero(),
60 count: C::scalar_one(),
61 }
62 }
63
64 #[inline(always)]
65 fn tropical_add(self, rhs: Self) -> Self {
66 if self.value > rhs.value {
67 self
68 } else if self.value < rhs.value {
69 rhs
70 } else {
71 Self {
73 value: self.value,
74 count: self.count.scalar_add(rhs.count),
75 }
76 }
77 }
78
79 #[inline(always)]
80 fn tropical_mul(self, rhs: Self) -> Self {
81 Self {
82 value: self.value.scalar_add(rhs.value),
83 count: self.count.scalar_mul(rhs.count),
84 }
85 }
86
87 #[inline(always)]
88 fn value(&self) -> T {
89 self.value
90 }
91
92 #[inline(always)]
93 fn from_scalar(s: T) -> Self {
94 Self::from_value(s)
95 }
96}
97
98impl<T: TropicalScalar, C: TropicalScalar> TropicalWithArgmax for CountingTropical<T, C> {
99 type Index = u32;
100
101 #[inline(always)]
102 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
103 if self.value > rhs.value {
104 (self, self_idx)
105 } else if self.value < rhs.value {
106 (rhs, rhs_idx)
107 } else {
108 (
110 Self {
111 value: self.value,
112 count: self.count.scalar_add(rhs.count),
113 },
114 self_idx,
115 )
116 }
117 }
118
119 #[inline(always)]
120 fn is_no_contribution(&self) -> bool {
121 self.value.is_drifted_neg_zero()
125 }
126}
127
128impl<T: TropicalScalar, C: TropicalScalar> SimdTropical for CountingTropical<T, C> {
129 const SIMD_AVAILABLE: bool = true;
131 const SIMD_WIDTH: usize = 8;
132}
133
134impl<T: TropicalScalar, C: TropicalScalar> Add for CountingTropical<T, C> {
135 type Output = Self;
136
137 #[inline(always)]
138 fn add(self, rhs: Self) -> Self::Output {
139 self.tropical_add(rhs)
140 }
141}
142
143impl<T: TropicalScalar, C: TropicalScalar> Mul for CountingTropical<T, C> {
144 type Output = Self;
145
146 #[inline(always)]
147 fn mul(self, rhs: Self) -> Self::Output {
148 self.tropical_mul(rhs)
149 }
150}
151
152impl<T: TropicalScalar, C: TropicalScalar> Default for CountingTropical<T, C> {
153 #[inline(always)]
154 fn default() -> Self {
155 Self::tropical_zero()
156 }
157}
158
159impl<T: TropicalScalar, C: TropicalScalar> fmt::Debug for CountingTropical<T, C> {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 write!(f, "CountingTropical({}, {})", self.value, self.count)
162 }
163}
164
165impl<T: TropicalScalar, C: TropicalScalar> fmt::Display for CountingTropical<T, C> {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 write!(f, "({}, {})", self.value, self.count)
168 }
169}
170
171impl<T: TropicalScalar, C: TropicalScalar> From<T> for CountingTropical<T, C> {
172 #[inline(always)]
173 fn from(value: T) -> Self {
174 Self::from_value(value)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();
        let one = CountingTropical::<f64>::tropical_one();

        // Right identities.
        let result = a.tropical_add(zero);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);

        let result = a.tropical_mul(one);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);

        // Left identities: the identity laws must not depend on operand order.
        let result = zero.tropical_add(a);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);

        let result = one.tropical_mul(a);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);
    }

    #[test]
    fn test_identity_elements() {
        let zero = CountingTropical::<f64>::tropical_zero();
        assert!(zero.value.is_infinite() && zero.value < 0.0);
        assert_eq!(zero.count, 0.0);

        let one = CountingTropical::<f64>::tropical_one();
        assert_eq!(one.value, 0.0);
        assert_eq!(one.count, 1.0);
    }

    #[test]
    fn test_zero_annihilates_under_multiplication() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();

        let result = a.tropical_mul(zero);
        assert!(result.value.is_infinite() && result.value < 0.0);
        assert_eq!(result.count, 0.0);

        let result = zero.tropical_mul(a);
        assert!(result.value.is_infinite() && result.value < 0.0);
        assert_eq!(result.count, 0.0);
    }

    #[test]
    fn test_addition_commutative() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);
        let c = CountingTropical::<f64>::new(1.0, 7.0);

        assert_eq!(a.tropical_add(b), b.tropical_add(a));
        assert_eq!(a.tropical_add(c), c.tropical_add(a));
    }

    #[test]
    fn test_multiplication() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert_eq!(result.value, 8.0);
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_addition_different_values() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_addition_equal_values() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0);
    }

    #[test]
    fn test_addition_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 1.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 1.0);
    }

    #[test]
    fn test_path_counting_example() {
        // Two groups of optimal paths with the same weight 3: one group of 1 path and
        // one of 2 paths. Combining them must report weight 3 reached by 3 paths.
        let path1 = CountingTropical::<f64>::new(3.0, 1.0);
        let path2 = CountingTropical::<f64>::new(3.0, 2.0);

        let result = path1.tropical_add(path2);
        assert_eq!(result.value, 3.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a + b;
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);

        let result = a * b;
        assert_eq!(result.value, 8.0);
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_default() {
        let d = CountingTropical::<f64>::default();
        assert!(d.value.is_infinite() && d.value < 0.0);
        assert_eq!(d.count, 0.0);
    }

    #[test]
    fn test_display_debug() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);

        assert_eq!(format!("{}", a), "(3, 2)");
        assert_eq!(format!("{:?}", a), "CountingTropical(3, 2)");
    }

    #[test]
    fn test_from() {
        let a: CountingTropical<f64> = 5.0.into();
        assert_eq!(a.value, 5.0);
        assert_eq!(a.count, 1.0);

        let b = CountingTropical::<f64>::from(3.0);
        assert_eq!(b.value, 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_from_value() {
        let a = CountingTropical::<f64>::from_value(7.0);
        assert_eq!(a.value, 7.0);
        assert_eq!(a.count, 1.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        assert_eq!(a.value(), 5.0);

        let b = CountingTropical::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 2.0);
        let b = CountingTropical::<f64>::new(3.0, 1.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_rhs_wins() {
        let a = CountingTropical::<f64>::new(3.0, 1.0);
        let b = CountingTropical::<f64>::new(7.0, 2.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 2);
    }

    #[test]
    fn test_argmax_equal_counts_added() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0);
        // On a tie the left-hand index is kept.
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_chain() {
        let mut acc = CountingTropical::<f64>::tropical_zero();
        let mut idx = 0u32;

        let values = [(3.0, 1.0), (7.0, 2.0), (7.0, 3.0), (5.0, 1.0)];
        for (k, &(val, count)) in values.iter().enumerate() {
            let candidate = CountingTropical::new(val, count);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        // Maximum 7 is attained at k = 1 and k = 2: counts merge (2 + 3) and the first
        // maximiser's index is kept.
        assert_eq!(acc.value, 7.0);
        assert_eq!(acc.count, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(CountingTropical::<f64>::SIMD_AVAILABLE);
        assert_eq!(CountingTropical::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    // `Clone` is exercised explicitly on purpose even though the type is `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn test_clone_copy() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a.value, a_copy.value);
        assert_eq!(a.count, a_copy.count);
        assert_eq!(a.value, a_clone.value);
        assert_eq!(a.count, a_clone.count);
    }

    #[test]
    fn test_eq() {
        let a1 = CountingTropical::<f64>::new(5.0, 2.0);
        let a2 = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = CountingTropical::<f32>::new(3.0, 2.0);
        let b = CountingTropical::<f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert!((result.value - 8.0).abs() < 1e-6);
        assert!((result.count - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_different_count_type() {
        let a = CountingTropical::<f64, f32>::new(3.0, 2.0);
        let b = CountingTropical::<f64, f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert_eq!(result.value, 8.0);
        assert!((result.count - 6.0).abs() < 1e-6);
    }
}