Skip to main content

tropical_gemm/types/
counting.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// CountingTropical semiring: tracks both the tropical value and the count of optimal paths.
7///
8/// For TropicalMaxPlus semantics:
9/// - Multiplication: (n₁, c₁) ⊗ (n₂, c₂) = (n₁ + n₂, c₁ × c₂)
10/// - Addition: (n₁, c₁) ⊕ (n₂, c₂) =
11///   - if n₁ > n₂: (n₁, c₁)
12///   - if n₁ < n₂: (n₂, c₂)
13///   - if n₁ = n₂: (n₁, c₁ + c₂)
14///
15/// This is used for:
16/// - Counting optimal paths in dynamic programming
17/// - Computing partition functions
18/// - Gradient computations in certain neural network architectures
19#[derive(Copy, Clone, PartialEq)]
20#[repr(C)]
21pub struct CountingTropical<T: TropicalScalar, C: TropicalScalar = T> {
22    /// The tropical value (using MaxPlus semantics).
23    pub value: T,
24    /// The count of paths achieving this value.
25    pub count: C,
26}
27
28impl<T: TropicalScalar, C: TropicalScalar> CountingTropical<T, C> {
29    /// Create a new CountingTropical value.
30    #[inline(always)]
31    pub fn new(value: T, count: C) -> Self {
32        Self { value, count }
33    }
34
35    /// Create a CountingTropical from a single value with count 1.
36    #[inline(always)]
37    pub fn from_value(value: T) -> Self {
38        Self {
39            value,
40            count: C::scalar_one(),
41        }
42    }
43}
44
45impl<T: TropicalScalar, C: TropicalScalar> TropicalSemiring for CountingTropical<T, C> {
46    type Scalar = T;
47
48    #[inline(always)]
49    fn tropical_zero() -> Self {
50        Self {
51            value: T::neg_infinity(),
52            count: C::scalar_zero(),
53        }
54    }
55
56    #[inline(always)]
57    fn tropical_one() -> Self {
58        Self {
59            value: T::scalar_zero(),
60            count: C::scalar_one(),
61        }
62    }
63
64    #[inline(always)]
65    fn tropical_add(self, rhs: Self) -> Self {
66        if self.value > rhs.value {
67            self
68        } else if self.value < rhs.value {
69            rhs
70        } else {
71            // Equal values: add counts
72            Self {
73                value: self.value,
74                count: self.count.scalar_add(rhs.count),
75            }
76        }
77    }
78
79    #[inline(always)]
80    fn tropical_mul(self, rhs: Self) -> Self {
81        Self {
82            value: self.value.scalar_add(rhs.value),
83            count: self.count.scalar_mul(rhs.count),
84        }
85    }
86
87    #[inline(always)]
88    fn value(&self) -> T {
89        self.value
90    }
91
92    #[inline(always)]
93    fn from_scalar(s: T) -> Self {
94        Self::from_value(s)
95    }
96}
97
98impl<T: TropicalScalar, C: TropicalScalar> TropicalWithArgmax for CountingTropical<T, C> {
99    type Index = u32;
100
101    #[inline(always)]
102    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
103        if self.value > rhs.value {
104            (self, self_idx)
105        } else if self.value < rhs.value {
106            (rhs, rhs_idx)
107        } else {
108            // Equal values: add counts, keep first index
109            (
110                Self {
111                    value: self.value,
112                    count: self.count.scalar_add(rhs.count),
113                },
114                self_idx,
115            )
116        }
117    }
118
119    #[inline(always)]
120    fn is_no_contribution(&self) -> bool {
121        // CountingTropical uses MaxPlus value semantics (zero = -∞), so a
122        // no-contribution cell's value drifts under the guard-free `+` exactly
123        // like TropicalMaxPlus and must be canonicalized the same way.
124        self.value.is_drifted_neg_zero()
125    }
126}
127
128impl<T: TropicalScalar, C: TropicalScalar> SimdTropical for CountingTropical<T, C> {
129    // SIMD for CountingTropical requires SOA layout
130    const SIMD_AVAILABLE: bool = true;
131    const SIMD_WIDTH: usize = 8;
132}
133
134impl<T: TropicalScalar, C: TropicalScalar> Add for CountingTropical<T, C> {
135    type Output = Self;
136
137    #[inline(always)]
138    fn add(self, rhs: Self) -> Self::Output {
139        self.tropical_add(rhs)
140    }
141}
142
143impl<T: TropicalScalar, C: TropicalScalar> Mul for CountingTropical<T, C> {
144    type Output = Self;
145
146    #[inline(always)]
147    fn mul(self, rhs: Self) -> Self::Output {
148        self.tropical_mul(rhs)
149    }
150}
151
152impl<T: TropicalScalar, C: TropicalScalar> Default for CountingTropical<T, C> {
153    #[inline(always)]
154    fn default() -> Self {
155        Self::tropical_zero()
156    }
157}
158
159impl<T: TropicalScalar, C: TropicalScalar> fmt::Debug for CountingTropical<T, C> {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        write!(f, "CountingTropical({}, {})", self.value, self.count)
162    }
163}
164
165impl<T: TropicalScalar, C: TropicalScalar> fmt::Display for CountingTropical<T, C> {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        write!(f, "({}, {})", self.value, self.count)
168    }
169}
170
171impl<T: TropicalScalar, C: TropicalScalar> From<T> for CountingTropical<T, C> {
172    #[inline(always)]
173    fn from(value: T) -> Self {
174        Self::from_value(value)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `x` is the tropical zero `(-∞, 0)`.
    fn assert_is_tropical_zero(x: CountingTropical<f64>) {
        assert!(x.value.is_infinite() && x.value < 0.0, "value = {}", x.value);
        assert_eq!(x.count, 0.0);
    }

    #[test]
    fn test_semiring_identity() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::tropical_zero();
        let one = CountingTropical::tropical_one();

        // a ⊕ 0 = a
        let result = a.tropical_add(zero);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);

        // a ⊗ 1 = a
        let result = a.tropical_mul(one);
        assert_eq!(result.value, a.value);
        assert_eq!(result.count, a.count);
    }

    #[test]
    fn test_identities_are_two_sided() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();
        let one = CountingTropical::<f64>::tropical_one();

        // 0 ⊕ a = a and 1 ⊗ a = a
        assert_eq!(zero.tropical_add(a), a);
        assert_eq!(one.tropical_mul(a), a);
    }

    #[test]
    fn test_zero_annihilates_multiplication() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let zero = CountingTropical::<f64>::tropical_zero();

        // 0 ⊗ a = a ⊗ 0 = 0: no path through an unreachable element survives.
        assert_is_tropical_zero(zero.tropical_mul(a));
        assert_is_tropical_zero(a.tropical_mul(zero));
    }

    #[test]
    fn test_zero_plus_zero_is_zero() {
        // -∞ ties with -∞, so the tie branch merges the (zero) counts.
        let zero = CountingTropical::<f64>::tropical_zero();
        assert_is_tropical_zero(zero.tropical_add(zero));
    }

    #[test]
    fn test_multiplication() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        // value = 3 + 5 = 8
        assert_eq!(result.value, 8.0);
        // count = 2 * 3 = 6
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_addition_different_values() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        // max(3, 5) = 5, keep count of winner
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_addition_equal_values() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        // same value, add counts
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0);
    }

    #[test]
    fn test_addition_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 1.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let result = a.tropical_add(b);
        // max(7, 5) = 7, keep count of winner
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 1.0);
    }

    #[test]
    fn test_addition_commutes() {
        let cases = [
            (CountingTropical::<f64>::new(3.0, 2.0), CountingTropical::new(5.0, 3.0)),
            (CountingTropical::new(5.0, 2.0), CountingTropical::new(5.0, 3.0)),
        ];
        for (a, b) in cases {
            assert_eq!(a.tropical_add(b), b.tropical_add(a));
        }
    }

    #[test]
    fn test_mul_distributes_over_add() {
        let a = CountingTropical::<f64>::new(2.0, 3.0);
        let b = CountingTropical::<f64>::new(5.0, 2.0);
        let c = CountingTropical::<f64>::new(5.0, 4.0);
        let d = CountingTropical::<f64>::new(1.0, 7.0);

        // Tie between b and c: a ⊗ (b ⊕ c) = (a ⊗ b) ⊕ (a ⊗ c), counts 3·(2+4) = 18.
        let lhs = a.tropical_mul(b.tropical_add(c));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(c));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs, CountingTropical::<f64>::new(7.0, 18.0));

        // Strict winner b over d: the loser's count must not leak into the result.
        let lhs = a.tropical_mul(b.tropical_add(d));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(d));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs, CountingTropical::<f64>::new(7.0, 6.0));
    }

    #[test]
    fn test_path_counting_example() {
        // Example: counting paths in a graph
        // Path A->B has value 3, count 1 (one path)
        // Path A->C->B has value 3, count 2 (two equivalent paths)
        // Total paths A->B with optimal value: 1 + 2 = 3

        let path1 = CountingTropical::<f64>::new(3.0, 1.0);
        let path2 = CountingTropical::<f64>::new(3.0, 2.0);

        let result = path1.tropical_add(path2);
        assert_eq!(result.value, 3.0);
        assert_eq!(result.count, 3.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        // Add operator
        let result = a + b;
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 3.0);

        // Mul operator
        let result = a * b;
        assert_eq!(result.value, 8.0);
        assert_eq!(result.count, 6.0);
    }

    #[test]
    fn test_default() {
        let d = CountingTropical::<f64>::default();
        assert!(d.value.is_infinite() && d.value < 0.0); // -inf
        assert_eq!(d.count, 0.0);
    }

    #[test]
    fn test_display_debug() {
        let a = CountingTropical::<f64>::new(3.0, 2.0);

        assert_eq!(format!("{}", a), "(3, 2)");
        assert_eq!(format!("{:?}", a), "CountingTropical(3, 2)");
    }

    #[test]
    fn test_from() {
        let a: CountingTropical<f64> = 5.0.into();
        assert_eq!(a.value, 5.0);
        assert_eq!(a.count, 1.0); // Default count is 1

        let b = CountingTropical::<f64>::from(3.0);
        assert_eq!(b.value, 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_from_value() {
        let a = CountingTropical::<f64>::from_value(7.0);
        assert_eq!(a.value, 7.0);
        assert_eq!(a.count, 1.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        assert_eq!(a.value(), 5.0);

        let b = CountingTropical::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
        assert_eq!(b.count, 1.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        let a = CountingTropical::<f64>::new(7.0, 2.0);
        let b = CountingTropical::<f64>::new(3.0, 1.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_rhs_wins() {
        let a = CountingTropical::<f64>::new(3.0, 1.0);
        let b = CountingTropical::<f64>::new(7.0, 2.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 7.0);
        assert_eq!(result.count, 2.0);
        assert_eq!(idx, 2);
    }

    #[test]
    fn test_argmax_equal_counts_added() {
        // Equal values: counts are added, first index is kept
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.value, 5.0);
        assert_eq!(result.count, 5.0); // 2 + 3
        assert_eq!(idx, 1); // First index is kept
    }

    #[test]
    fn test_argmax_chain() {
        let mut acc = CountingTropical::<f64>::tropical_zero();
        let mut idx = 0u32;

        // Values with different counts
        let values = [(3.0, 1.0), (7.0, 2.0), (7.0, 3.0), (5.0, 1.0)];
        for (k, &(val, count)) in values.iter().enumerate() {
            let candidate = CountingTropical::new(val, count);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        // Max value is 7.0, first encountered at k=1
        // Counts: 2 + 3 = 5 (both k=1 and k=2 have value 7.0)
        assert_eq!(acc.value, 7.0);
        assert_eq!(acc.count, 5.0);
        assert_eq!(idx, 1); // First index where max occurred
    }

    #[test]
    fn test_simd_tropical() {
        assert!(CountingTropical::<f64>::SIMD_AVAILABLE);
        assert_eq!(CountingTropical::<f64>::SIMD_WIDTH, 8);
    }

    // `.clone()` on a `Copy` type is intentional here: the test exercises the
    // derived `Clone` impl alongside `Copy`.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn test_clone_copy() {
        let a = CountingTropical::<f64>::new(5.0, 2.0);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a.value, a_copy.value);
        assert_eq!(a.count, a_copy.count);
        assert_eq!(a.value, a_clone.value);
        assert_eq!(a.count, a_clone.count);
    }

    #[test]
    fn test_eq() {
        let a1 = CountingTropical::<f64>::new(5.0, 2.0);
        let a2 = CountingTropical::<f64>::new(5.0, 2.0);
        let b = CountingTropical::<f64>::new(5.0, 3.0);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = CountingTropical::<f32>::new(3.0, 2.0);
        let b = CountingTropical::<f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert!((result.value - 8.0).abs() < 1e-6);
        assert!((result.count - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_different_count_type() {
        // Use different types for value and count
        let a = CountingTropical::<f64, f32>::new(3.0, 2.0);
        let b = CountingTropical::<f64, f32>::new(5.0, 3.0);

        let result = a.tropical_mul(b);
        assert_eq!(result.value, 8.0);
        assert!((result.count - 6.0).abs() < 1e-6);
    }
}