tropical_gemm/types/
counting.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// CountingTropical semiring: tracks both the tropical value and the count of optimal paths.
7///
8/// For TropicalMaxPlus semantics:
9/// - Multiplication: (n₁, c₁) ⊗ (n₂, c₂) = (n₁ + n₂, c₁ × c₂)
10/// - Addition: (n₁, c₁) ⊕ (n₂, c₂) =
11///   - if n₁ > n₂: (n₁, c₁)
12///   - if n₁ < n₂: (n₂, c₂)
13///   - if n₁ = n₂: (n₁, c₁ + c₂)
14///
15/// This is used for:
16/// - Counting optimal paths in dynamic programming
17/// - Computing partition functions
18/// - Gradient computations in certain neural network architectures
19#[derive(Copy, Clone, PartialEq)]
20#[repr(C)]
21pub struct CountingTropical<T: TropicalScalar, C: TropicalScalar = T> {
22    /// The tropical value (using MaxPlus semantics).
23    pub value: T,
24    /// The count of paths achieving this value.
25    pub count: C,
26}
27
28impl<T: TropicalScalar, C: TropicalScalar> CountingTropical<T, C> {
29    /// Create a new CountingTropical value.
30    #[inline(always)]
31    pub fn new(value: T, count: C) -> Self {
32        Self { value, count }
33    }
34
35    /// Create a CountingTropical from a single value with count 1.
36    #[inline(always)]
37    pub fn from_value(value: T) -> Self {
38        Self {
39            value,
40            count: C::scalar_one(),
41        }
42    }
43}
44
45impl<T: TropicalScalar, C: TropicalScalar> TropicalSemiring for CountingTropical<T, C> {
46    type Scalar = T;
47
48    #[inline(always)]
49    fn tropical_zero() -> Self {
50        Self {
51            value: T::neg_infinity(),
52            count: C::scalar_zero(),
53        }
54    }
55
56    #[inline(always)]
57    fn tropical_one() -> Self {
58        Self {
59            value: T::scalar_zero(),
60            count: C::scalar_one(),
61        }
62    }
63
64    #[inline(always)]
65    fn tropical_add(self, rhs: Self) -> Self {
66        if self.value > rhs.value {
67            self
68        } else if self.value < rhs.value {
69            rhs
70        } else {
71            // Equal values: add counts
72            Self {
73                value: self.value,
74                count: self.count.scalar_add(rhs.count),
75            }
76        }
77    }
78
79    #[inline(always)]
80    fn tropical_mul(self, rhs: Self) -> Self {
81        Self {
82            value: self.value.scalar_add(rhs.value),
83            count: self.count.scalar_mul(rhs.count),
84        }
85    }
86
87    #[inline(always)]
88    fn value(&self) -> T {
89        self.value
90    }
91
92    #[inline(always)]
93    fn from_scalar(s: T) -> Self {
94        Self::from_value(s)
95    }
96}
97
98impl<T: TropicalScalar, C: TropicalScalar> TropicalWithArgmax for CountingTropical<T, C> {
99    type Index = u32;
100
101    #[inline(always)]
102    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
103        if self.value > rhs.value {
104            (self, self_idx)
105        } else if self.value < rhs.value {
106            (rhs, rhs_idx)
107        } else {
108            // Equal values: add counts, keep first index
109            (
110                Self {
111                    value: self.value,
112                    count: self.count.scalar_add(rhs.count),
113                },
114                self_idx,
115            )
116        }
117    }
118}
119
120impl<T: TropicalScalar, C: TropicalScalar> SimdTropical for CountingTropical<T, C> {
121    // SIMD for CountingTropical requires SOA layout
122    const SIMD_AVAILABLE: bool = true;
123    const SIMD_WIDTH: usize = 8;
124}
125
126impl<T: TropicalScalar, C: TropicalScalar> Add for CountingTropical<T, C> {
127    type Output = Self;
128
129    #[inline(always)]
130    fn add(self, rhs: Self) -> Self::Output {
131        self.tropical_add(rhs)
132    }
133}
134
135impl<T: TropicalScalar, C: TropicalScalar> Mul for CountingTropical<T, C> {
136    type Output = Self;
137
138    #[inline(always)]
139    fn mul(self, rhs: Self) -> Self::Output {
140        self.tropical_mul(rhs)
141    }
142}
143
144impl<T: TropicalScalar, C: TropicalScalar> Default for CountingTropical<T, C> {
145    #[inline(always)]
146    fn default() -> Self {
147        Self::tropical_zero()
148    }
149}
150
151impl<T: TropicalScalar, C: TropicalScalar> fmt::Debug for CountingTropical<T, C> {
152    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153        write!(f, "CountingTropical({}, {})", self.value, self.count)
154    }
155}
156
157impl<T: TropicalScalar, C: TropicalScalar> fmt::Display for CountingTropical<T, C> {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        write!(f, "({}, {})", self.value, self.count)
160    }
161}
162
163impl<T: TropicalScalar, C: TropicalScalar> From<T> for CountingTropical<T, C> {
164    #[inline(always)]
165    fn from(value: T) -> Self {
166        Self::from_value(value)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for the common `f64`-valued, `f64`-counted element.
    fn ct(value: f64, count: f64) -> CountingTropical<f64> {
        CountingTropical::new(value, count)
    }

    #[test]
    fn test_semiring_identity() {
        let a = ct(5.0, 2.0);

        // a ⊕ 0 = a
        let zero = CountingTropical::<f64>::tropical_zero();
        assert_eq!(a.tropical_add(zero), a);

        // a ⊗ 1 = a
        let one = CountingTropical::<f64>::tropical_one();
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_multiplication() {
        // value = 3 + 5 = 8, count = 2 * 3 = 6
        let product = ct(3.0, 2.0).tropical_mul(ct(5.0, 3.0));
        assert_eq!(product, ct(8.0, 6.0));
    }

    #[test]
    fn test_addition_different_values() {
        // max(3, 5) = 5, and the winner keeps its own count
        let sum = ct(3.0, 2.0).tropical_add(ct(5.0, 3.0));
        assert_eq!(sum, ct(5.0, 3.0));
    }

    #[test]
    fn test_addition_equal_values() {
        // Same value on both sides: the counts are added
        let sum = ct(5.0, 2.0).tropical_add(ct(5.0, 3.0));
        assert_eq!(sum, ct(5.0, 5.0));
    }

    #[test]
    fn test_addition_self_wins() {
        // max(7, 5) = 7, and the winner keeps its own count
        let sum = ct(7.0, 1.0).tropical_add(ct(5.0, 3.0));
        assert_eq!(sum, ct(7.0, 1.0));
    }

    #[test]
    fn test_path_counting_example() {
        // Counting paths in a graph:
        //   A->B    has value 3 and count 1 (a single path)
        //   A->C->B has value 3 and count 2 (two equivalent paths)
        // so the number of optimal A->B paths is 1 + 2 = 3.
        let direct = ct(3.0, 1.0);
        let via_c = ct(3.0, 2.0);

        assert_eq!(direct.tropical_add(via_c), ct(3.0, 3.0));
    }

    #[test]
    fn test_operator_overloads() {
        let (a, b) = (ct(3.0, 2.0), ct(5.0, 3.0));

        // `+` is tropical addition
        assert_eq!(a + b, ct(5.0, 3.0));

        // `*` is tropical multiplication
        assert_eq!(a * b, ct(8.0, 6.0));
    }

    #[test]
    fn test_default() {
        let d = CountingTropical::<f64>::default();
        assert_eq!(d.value, f64::NEG_INFINITY); // -inf
        assert_eq!(d.count, 0.0);
    }

    #[test]
    fn test_display_debug() {
        let a = ct(3.0, 2.0);

        assert_eq!(a.to_string(), "(3, 2)");
        assert_eq!(format!("{:?}", a), "CountingTropical(3, 2)");
    }

    #[test]
    fn test_from() {
        // Conversions from a bare scalar use a count of 1
        let a: CountingTropical<f64> = 5.0.into();
        assert_eq!(a, ct(5.0, 1.0));

        let b = CountingTropical::<f64>::from(3.0);
        assert_eq!(b, ct(3.0, 1.0));
    }

    #[test]
    fn test_from_value() {
        let a = CountingTropical::<f64>::from_value(7.0);
        assert_eq!(a, ct(7.0, 1.0));
    }

    #[test]
    fn test_value_and_from_scalar() {
        assert_eq!(ct(5.0, 2.0).value(), 5.0);

        let b = CountingTropical::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        let outcome = ct(7.0, 2.0).tropical_add_argmax(1, ct(3.0, 1.0), 2);
        assert_eq!(outcome, (ct(7.0, 2.0), 1));
    }

    #[test]
    fn test_argmax_rhs_wins() {
        let outcome = ct(3.0, 1.0).tropical_add_argmax(1, ct(7.0, 2.0), 2);
        assert_eq!(outcome, (ct(7.0, 2.0), 2));
    }

    #[test]
    fn test_argmax_equal_counts_added() {
        // Equal values: counts are added (2 + 3) and the first index is kept
        let outcome = ct(5.0, 2.0).tropical_add_argmax(1, ct(5.0, 3.0), 2);
        assert_eq!(outcome, (ct(5.0, 5.0), 1));
    }

    #[test]
    fn test_argmax_chain() {
        // Values with different counts
        let values: [(f64, f64); 4] = [(3.0, 1.0), (7.0, 2.0), (7.0, 3.0), (5.0, 1.0)];

        let start = (CountingTropical::<f64>::tropical_zero(), 0u32);
        let (acc, idx) = values
            .iter()
            .enumerate()
            .fold(start, |(acc, idx), (k, &(val, count))| {
                acc.tropical_add_argmax(idx, CountingTropical::new(val, count), k as u32)
            });

        // The maximum 7.0 first appears at k=1; k=1 and k=2 both carry it,
        // so the counts add up to 2 + 3 = 5.
        assert_eq!(acc, ct(7.0, 5.0));
        assert_eq!(idx, 1); // First index where the max occurred
    }

    #[test]
    fn test_simd_tropical() {
        assert!(CountingTropical::<f64>::SIMD_AVAILABLE);
        assert_eq!(CountingTropical::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    fn test_clone_copy() {
        let a = ct(5.0, 2.0);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a_copy, a);
        assert_eq!(a_clone, a);
    }

    #[test]
    fn test_eq() {
        let a1 = ct(5.0, 2.0);
        let a2 = ct(5.0, 2.0);
        let b = ct(5.0, 3.0);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = CountingTropical::<f32>::new(3.0, 2.0);
        let b = CountingTropical::<f32>::new(5.0, 3.0);

        let product = a.tropical_mul(b);
        assert!((product.value - 8.0).abs() < 1e-6);
        assert!((product.count - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_different_count_type() {
        // The value and the count may use different scalar types
        let a = CountingTropical::<f64, f32>::new(3.0, 2.0);
        let b = CountingTropical::<f64, f32>::new(5.0, 3.0);

        let product = a.tropical_mul(b);
        assert_eq!(product.value, 8.0);
        assert!((product.count - 6.0).abs() < 1e-6);
    }
}