tropical_gemm/types/
max_plus.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// TropicalMaxPlus semiring: (ℝ ∪ {-∞}, max, +)
7///
8/// - Addition (⊕) = max
9/// - Multiplication (⊗) = +
10/// - Zero = -∞
11/// - One = 0
12///
13/// This is the classic tropical semiring used in:
14/// - Viterbi algorithm
15/// - Shortest path algorithms (with negated weights)
16/// - Log-space probability computations
17#[derive(Copy, Clone, PartialEq)]
18#[repr(transparent)]
19pub struct TropicalMaxPlus<T: TropicalScalar>(pub T);
20
21impl<T: TropicalScalar> TropicalMaxPlus<T> {
22    /// Create a new TropicalMaxPlus value.
23    #[inline(always)]
24    pub fn new(value: T) -> Self {
25        Self(value)
26    }
27}
28
29impl<T: TropicalScalar> TropicalSemiring for TropicalMaxPlus<T> {
30    type Scalar = T;
31
32    #[inline(always)]
33    fn tropical_zero() -> Self {
34        Self(T::neg_infinity())
35    }
36
37    #[inline(always)]
38    fn tropical_one() -> Self {
39        Self(T::scalar_zero())
40    }
41
42    #[inline(always)]
43    fn tropical_add(self, rhs: Self) -> Self {
44        Self(self.0.scalar_max(rhs.0))
45    }
46
47    #[inline(always)]
48    fn tropical_mul(self, rhs: Self) -> Self {
49        Self(self.0.scalar_add(rhs.0))
50    }
51
52    #[inline(always)]
53    fn value(&self) -> T {
54        self.0
55    }
56
57    #[inline(always)]
58    fn from_scalar(s: T) -> Self {
59        Self(s)
60    }
61}
62
63impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxPlus<T> {
64    type Index = u32;
65
66    #[inline(always)]
67    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
68        if self.0 >= rhs.0 {
69            (self, self_idx)
70        } else {
71            (rhs, rhs_idx)
72        }
73    }
74}
75
76impl<T: TropicalScalar> SimdTropical for TropicalMaxPlus<T> {
77    const SIMD_AVAILABLE: bool = true;
78    const SIMD_WIDTH: usize = 8; // f32x8 for AVX2
79}
80
81impl<T: TropicalScalar> Add for TropicalMaxPlus<T> {
82    type Output = Self;
83
84    #[inline(always)]
85    fn add(self, rhs: Self) -> Self::Output {
86        self.tropical_add(rhs)
87    }
88}
89
90impl<T: TropicalScalar> Mul for TropicalMaxPlus<T> {
91    type Output = Self;
92
93    #[inline(always)]
94    fn mul(self, rhs: Self) -> Self::Output {
95        self.tropical_mul(rhs)
96    }
97}
98
99impl<T: TropicalScalar> Default for TropicalMaxPlus<T> {
100    #[inline(always)]
101    fn default() -> Self {
102        Self::tropical_zero()
103    }
104}
105
106impl<T: TropicalScalar> fmt::Debug for TropicalMaxPlus<T> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        write!(f, "TropicalMaxPlus({})", self.0)
109    }
110}
111
112impl<T: TropicalScalar> fmt::Display for TropicalMaxPlus<T> {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "{}", self.0)
115    }
116}
117
118impl<T: TropicalScalar> From<T> for TropicalMaxPlus<T> {
119    #[inline(always)]
120    fn from(value: T) -> Self {
121        Self(value)
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `f64` max-plus value.
    fn mp(x: f64) -> TropicalMaxPlus<f64> {
        TropicalMaxPlus::new(x)
    }

    #[test]
    fn test_semiring_identity() {
        let x = mp(5.0);
        // x ⊕ 0 = x, where the tropical zero is -∞
        assert_eq!(x.tropical_add(TropicalMaxPlus::tropical_zero()), x);
        // x ⊗ 1 = x, where the tropical one is 0
        assert_eq!(x.tropical_mul(TropicalMaxPlus::tropical_one()), x);
    }

    #[test]
    fn test_operations() {
        let (lo, hi) = (mp(3.0), mp(5.0));
        assert_eq!(lo.tropical_add(hi).0, 5.0); // max(3, 5)
        assert_eq!(lo.tropical_mul(hi).0, 8.0); // 3 + 5
    }

    #[test]
    fn test_argmax() {
        let (winner, at) = mp(3.0).tropical_add_argmax(0, mp(5.0), 1);
        assert_eq!((winner.0, at), (5.0, 1));
    }

    #[test]
    fn test_argmax_left_wins() {
        let (winner, at) = mp(7.0).tropical_add_argmax(10, mp(3.0), 20);
        // The larger left operand keeps its own index.
        assert_eq!((winner.0, at), (7.0, 10));
    }

    #[test]
    fn test_argmax_equal_values() {
        // On a tie, the `>=` comparison favours the left operand.
        let (winner, at) = mp(5.0).tropical_add_argmax(1, mp(5.0), 2);
        assert_eq!((winner.0, at), (5.0, 1));
    }

    #[test]
    fn test_argmax_chain() {
        // Mimic a k-loop accumulation; the maximum sits at position 1.
        let samples = [3.0, 7.0, 2.0, 5.0];
        let (best, best_at) = samples.iter().zip(0u32..).fold(
            (TropicalMaxPlus::<f64>::tropical_zero(), 0u32),
            |(acc, idx), (&v, k)| acc.tropical_add_argmax(idx, mp(v), k),
        );
        assert_eq!(best.0, 7.0);
        assert_eq!(best_at, 1);
    }

    #[test]
    fn test_argmax_neg_infinity() {
        let floor = TropicalMaxPlus::<f64>::tropical_zero(); // -∞
        let (winner, at) = floor.tropical_add_argmax(0, mp(-100.0), 1);
        assert_eq!(winner.0, -100.0);
        assert_eq!(at, 1); // any finite value beats -∞
    }

    #[test]
    fn test_absorbing_zero() {
        // 5 + (-∞) = -∞: the tropical zero absorbs under ⊗.
        let product = mp(5.0).tropical_mul(TropicalMaxPlus::tropical_zero());
        assert_eq!(product.0, f64::NEG_INFINITY);
    }

    #[test]
    fn test_operator_overloads() {
        let (lo, hi) = (mp(3.0), mp(5.0));
        for (lhs, rhs) in [(lo, hi), (hi, lo)] {
            assert_eq!((lhs + rhs).0, 5.0); // `+` is ⊕ = max
            assert_eq!((lhs * rhs).0, 8.0); // `*` is ⊗ = +
        }
    }

    #[test]
    fn test_default() {
        let fallback = TropicalMaxPlus::<f64>::default();
        assert_eq!(fallback.0, f64::NEG_INFINITY);
        assert_eq!(fallback, TropicalMaxPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let five = mp(5.0);
        assert_eq!(five.to_string(), "5");
        assert_eq!(format!("{five:?}"), "TropicalMaxPlus(5)");
    }

    #[test]
    fn test_from() {
        let via_into: TropicalMaxPlus<f64> = 5.0.into();
        let via_from = TropicalMaxPlus::<f64>::from(3.0);
        assert_eq!(via_into.0, 5.0);
        assert_eq!(via_from.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        assert_eq!(mp(5.0).value(), 5.0);
        assert_eq!(TropicalMaxPlus::<f64>::from_scalar(3.0).value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        type Mp = TropicalMaxPlus<f64>;
        assert!(Mp::SIMD_AVAILABLE);
        assert_eq!(Mp::SIMD_WIDTH, 8);
    }

    #[test]
    fn test_clone_copy() {
        let original = mp(5.0);
        #[allow(clippy::clone_on_copy)] // exercising the Clone impl on purpose
        let cloned = original.clone();
        let copied = original;
        assert_eq!(copied, original);
        assert_eq!(cloned, original);
    }

    #[test]
    fn test_eq() {
        assert_eq!(mp(5.0), mp(5.0));
        assert_ne!(mp(5.0), mp(3.0));
    }

    #[test]
    fn test_f32() {
        let (lo, hi) = (TropicalMaxPlus::new(3.0f32), TropicalMaxPlus::new(5.0f32));
        let close = |got: f32, want: f32| (got - want).abs() < 1e-6;
        assert!(close(lo.tropical_add(hi).0, 5.0));
        assert!(close(lo.tropical_mul(hi).0, 8.0));
    }
}