Skip to main content

tropical_gemm/types/max_plus.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// TropicalMaxPlus semiring: (ℝ ∪ {-∞}, max, +)
7///
8/// - Addition (⊕) = max
9/// - Multiplication (⊗) = +
10/// - Zero = -∞
11/// - One = 0
12///
13/// This is the classic tropical semiring used in:
14/// - Viterbi algorithm
15/// - Shortest path algorithms (with negated weights)
16/// - Log-space probability computations
17#[derive(Copy, Clone, PartialEq)]
18#[repr(transparent)]
19pub struct TropicalMaxPlus<T: TropicalScalar>(pub T);
20
21impl<T: TropicalScalar> TropicalMaxPlus<T> {
22    /// Create a new TropicalMaxPlus value.
23    #[inline(always)]
24    pub fn new(value: T) -> Self {
25        Self(value)
26    }
27}
28
29impl<T: TropicalScalar> TropicalSemiring for TropicalMaxPlus<T> {
30    type Scalar = T;
31
32    #[inline(always)]
33    fn tropical_zero() -> Self {
34        Self(T::neg_infinity())
35    }
36
37    #[inline(always)]
38    fn tropical_one() -> Self {
39        Self(T::scalar_zero())
40    }
41
42    #[inline(always)]
43    fn tropical_add(self, rhs: Self) -> Self {
44        Self(self.0.scalar_max(rhs.0))
45    }
46
47    #[inline(always)]
48    fn tropical_mul(self, rhs: Self) -> Self {
49        Self(self.0.scalar_add(rhs.0))
50    }
51
52    #[inline(always)]
53    fn value(&self) -> T {
54        self.0
55    }
56
57    #[inline(always)]
58    fn from_scalar(s: T) -> Self {
59        Self(s)
60    }
61}
62
63impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxPlus<T> {
64    type Index = u32;
65
66    #[inline(always)]
67    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
68        if self.0 >= rhs.0 {
69            (self, self_idx)
70        } else {
71            (rhs, rhs_idx)
72        }
73    }
74
75    #[inline(always)]
76    fn is_no_contribution(&self) -> bool {
77        self.0.is_drifted_neg_zero()
78    }
79}
80
81impl<T: TropicalScalar> SimdTropical for TropicalMaxPlus<T> {
82    const SIMD_AVAILABLE: bool = true;
83    const SIMD_WIDTH: usize = 8; // f32x8 for AVX2
84}
85
86impl<T: TropicalScalar> Add for TropicalMaxPlus<T> {
87    type Output = Self;
88
89    #[inline(always)]
90    fn add(self, rhs: Self) -> Self::Output {
91        self.tropical_add(rhs)
92    }
93}
94
95impl<T: TropicalScalar> Mul for TropicalMaxPlus<T> {
96    type Output = Self;
97
98    #[inline(always)]
99    fn mul(self, rhs: Self) -> Self::Output {
100        self.tropical_mul(rhs)
101    }
102}
103
104impl<T: TropicalScalar> Default for TropicalMaxPlus<T> {
105    #[inline(always)]
106    fn default() -> Self {
107        Self::tropical_zero()
108    }
109}
110
111impl<T: TropicalScalar> fmt::Debug for TropicalMaxPlus<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        write!(f, "TropicalMaxPlus({})", self.0)
114    }
115}
116
117impl<T: TropicalScalar> fmt::Display for TropicalMaxPlus<T> {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        write!(f, "{}", self.0)
120    }
121}
122
123impl<T: TropicalScalar> From<T> for TropicalMaxPlus<T> {
124    #[inline(always)]
125    fn from(value: T) -> Self {
126        Self(value)
127    }
128}
129
#[cfg(test)]
mod tests {
    // Unit tests for the max-plus semiring: identities, operators, argmax
    // tie-breaking, trait impls, and the basic semiring laws.
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = TropicalMaxPlus::new(5.0f64);
        let zero = TropicalMaxPlus::tropical_zero();
        let one = TropicalMaxPlus::tropical_one();

        // a ⊕ 0 = a
        assert_eq!(a.tropical_add(zero), a);
        // a ⊗ 1 = a
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_operations() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        // max(3, 5) = 5
        assert_eq!(a.tropical_add(b).0, 5.0);
        // 3 + 5 = 8
        assert_eq!(a.tropical_mul(b).0, 8.0);
    }

    #[test]
    fn test_add_commutative_and_associative() {
        let a = TropicalMaxPlus::new(1.0f64);
        let b = TropicalMaxPlus::new(-4.0f64);
        let c = TropicalMaxPlus::new(2.5f64);

        // a ⊕ b = b ⊕ a
        assert_eq!(a.tropical_add(b), b.tropical_add(a));
        // (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c)
        assert_eq!(
            a.tropical_add(b).tropical_add(c),
            a.tropical_add(b.tropical_add(c))
        );
    }

    #[test]
    fn test_mul_distributes_over_add() {
        let a = TropicalMaxPlus::new(2.0f64);
        let b = TropicalMaxPlus::new(3.0f64);
        let c = TropicalMaxPlus::new(-1.0f64);

        // a ⊗ (b ⊕ c) = (a ⊗ b) ⊕ (a ⊗ c): 2 + max(3, -1) = max(5, 1) = 5
        let lhs = a.tropical_mul(b.tropical_add(c));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(c));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs.0, 5.0);
    }

    #[test]
    fn test_argmax() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_left_wins() {
        let a = TropicalMaxPlus::new(7.0f64);
        let b = TropicalMaxPlus::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(10, b, 20);
        assert_eq!(result.0, 7.0);
        assert_eq!(idx, 10); // Left wins, keep left index
    }

    #[test]
    fn test_argmax_equal_values() {
        // When values are equal, left (self) wins (>= comparison)
        let a = TropicalMaxPlus::new(5.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1); // Equal, so left (self) wins
    }

    #[test]
    fn test_argmax_tie_at_neg_infinity_keeps_left() {
        // -inf >= -inf holds, so a tie between two tropical zeros keeps the left index.
        let a = TropicalMaxPlus::<f64>::tropical_zero();
        let b = TropicalMaxPlus::<f64>::tropical_zero();

        let (result, idx) = a.tropical_add_argmax(4, b, 9);
        assert!(result.0.is_infinite() && result.0 < 0.0);
        assert_eq!(idx, 4);
    }

    #[test]
    fn test_argmax_chain() {
        // Simulate accumulating through k iterations
        let mut acc = TropicalMaxPlus::tropical_zero();
        let mut idx = 0u32;

        let values = [3.0, 7.0, 2.0, 5.0]; // Max is at index 1
        for (k, &val) in values.iter().enumerate() {
            let candidate = TropicalMaxPlus::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        assert_eq!(acc.0, 7.0);
        assert_eq!(idx, 1); // Index where max occurred
    }

    #[test]
    fn test_argmax_neg_infinity() {
        let a = TropicalMaxPlus::tropical_zero(); // -inf
        let b = TropicalMaxPlus::new(-100.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, -100.0);
        assert_eq!(idx, 1); // -100 > -inf
    }

    #[test]
    fn test_absorbing_zero() {
        let a = TropicalMaxPlus::new(5.0f64);
        let zero = TropicalMaxPlus::tropical_zero();

        // a ⊗ 0 = a + (-inf) = -inf
        // In tropical max-plus, multiplying by zero (adding -inf) gives -inf
        let result = a.tropical_mul(zero);
        assert!(result.0.is_infinite() && result.0 < 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        // Add operator (max)
        assert_eq!((a + b).0, 5.0);
        assert_eq!((b + a).0, 5.0);

        // Mul operator (add)
        assert_eq!((a * b).0, 8.0);
        assert_eq!((b * a).0, 8.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMaxPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 < 0.0); // -inf
        assert_eq!(d, TropicalMaxPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMaxPlus::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMaxPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMaxPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMaxPlus::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMaxPlus::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMaxPlus::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMaxPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMaxPlus::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    fn test_clone_copy() {
        let a = TropicalMaxPlus::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMaxPlus::new(5.0f64);
        let a2 = TropicalMaxPlus::new(5.0f64);
        let b = TropicalMaxPlus::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = TropicalMaxPlus::new(3.0f32);
        let b = TropicalMaxPlus::new(5.0f32);

        assert!((a.tropical_add(b).0 - 5.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 8.0).abs() < 1e-6);
    }
}