Skip to main content

tropical_gemm/types/
min_plus.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// TropicalMinPlus semiring: (ℝ ∪ {+∞}, min, +)
7///
8/// - Addition (⊕) = min
9/// - Multiplication (⊗) = +
10/// - Zero = +∞
11/// - One = 0
12///
13/// This is used for:
14/// - Shortest path algorithms (Dijkstra, Floyd-Warshall)
15/// - Dynamic programming with minimum cost
16#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMinPlus<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMinPlus<T> {
21    /// Create a new TropicalMinPlus value.
22    #[inline(always)]
23    pub fn new(value: T) -> Self {
24        Self(value)
25    }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMinPlus<T> {
29    type Scalar = T;
30
31    #[inline(always)]
32    fn tropical_zero() -> Self {
33        Self(T::pos_infinity())
34    }
35
36    #[inline(always)]
37    fn tropical_one() -> Self {
38        Self(T::scalar_zero())
39    }
40
41    #[inline(always)]
42    fn tropical_add(self, rhs: Self) -> Self {
43        Self(self.0.scalar_min(rhs.0))
44    }
45
46    #[inline(always)]
47    fn tropical_mul(self, rhs: Self) -> Self {
48        Self(self.0.scalar_add(rhs.0))
49    }
50
51    #[inline(always)]
52    fn value(&self) -> T {
53        self.0
54    }
55
56    #[inline(always)]
57    fn from_scalar(s: T) -> Self {
58        Self(s)
59    }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMinPlus<T> {
63    type Index = u32;
64
65    #[inline(always)]
66    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67        // For min, we track argmin
68        if self.0 <= rhs.0 {
69            (self, self_idx)
70        } else {
71            (rhs, rhs_idx)
72        }
73    }
74
75    #[inline(always)]
76    fn is_no_contribution(&self) -> bool {
77        self.0.is_drifted_pos_zero()
78    }
79}
80
81impl<T: TropicalScalar> SimdTropical for TropicalMinPlus<T> {
82    const SIMD_AVAILABLE: bool = true;
83    const SIMD_WIDTH: usize = 8;
84}
85
86impl<T: TropicalScalar> Add for TropicalMinPlus<T> {
87    type Output = Self;
88
89    #[inline(always)]
90    fn add(self, rhs: Self) -> Self::Output {
91        self.tropical_add(rhs)
92    }
93}
94
95impl<T: TropicalScalar> Mul for TropicalMinPlus<T> {
96    type Output = Self;
97
98    #[inline(always)]
99    fn mul(self, rhs: Self) -> Self::Output {
100        self.tropical_mul(rhs)
101    }
102}
103
104impl<T: TropicalScalar> Default for TropicalMinPlus<T> {
105    #[inline(always)]
106    fn default() -> Self {
107        Self::tropical_zero()
108    }
109}
110
111impl<T: TropicalScalar> fmt::Debug for TropicalMinPlus<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        write!(f, "TropicalMinPlus({})", self.0)
114    }
115}
116
117impl<T: TropicalScalar> fmt::Display for TropicalMinPlus<T> {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        write!(f, "{}", self.0)
120    }
121}
122
123impl<T: TropicalScalar> From<T> for TropicalMinPlus<T> {
124    #[inline(always)]
125    fn from(value: T) -> Self {
126        Self(value)
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// `+∞` is the identity for `⊕` and `0` is the identity for `⊗`,
    /// checked with the identity on either side.
    #[test]
    fn test_semiring_identity() {
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::<f64>::tropical_zero();
        let one = TropicalMinPlus::<f64>::tropical_one();

        // a ⊕ 0 = 0 ⊕ a = a
        assert_eq!(a.tropical_add(zero), a);
        assert_eq!(zero.tropical_add(a), a);
        // a ⊗ 1 = 1 ⊗ a = a
        assert_eq!(a.tropical_mul(one), a);
        assert_eq!(one.tropical_mul(a), a);
    }

    /// Basic `⊕` (min) and `⊗` (+) on two finite values.
    #[test]
    fn test_operations() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // min(3, 5) = 3
        assert_eq!(a.tropical_add(b).0, 3.0);
        // 3 + 5 = 8
        assert_eq!(a.tropical_mul(b).0, 8.0);
    }

    /// The two operations read as "pick the cheaper path" and "extend a path".
    #[test]
    fn test_shortest_path_scenario() {
        // Simulating: path cost a=10, path cost b=5, combine = min(10,5) = 5
        let a = TropicalMinPlus::new(10.0f64);
        let b = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.tropical_add(b).0, 5.0);

        // Extending a path: cost=5, edge=3, total = 5+3 = 8
        let path = TropicalMinPlus::new(5.0f64);
        let edge = TropicalMinPlus::new(3.0f64);
        assert_eq!(path.tropical_mul(edge).0, 8.0);
    }

    #[test]
    fn test_argmin_right_wins() {
        // For MinPlus, argmax actually tracks argmin
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 3.0);
        assert_eq!(idx, 1); // Right has smaller value
    }

    #[test]
    fn test_argmin_left_wins() {
        let a = TropicalMinPlus::new(2.0f64);
        let b = TropicalMinPlus::new(7.0f64);

        let (result, idx) = a.tropical_add_argmax(10, b, 20);
        assert_eq!(result.0, 2.0);
        assert_eq!(idx, 10); // Left has smaller value
    }

    #[test]
    fn test_argmin_equal_values() {
        // When values are equal, left (self) wins (<= comparison)
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1); // Equal, so left (self) wins
    }

    #[test]
    fn test_argmin_chain() {
        // Simulate accumulating through k iterations - find minimum
        let mut acc = TropicalMinPlus::<f64>::tropical_zero(); // +inf
        let mut idx = 0u32;

        let values = [8.0f64, 3.0, 9.0, 5.0]; // Min is at index 1
        // Pair each value with a u32 counter instead of casting `usize as u32`.
        for (k, &val) in (0u32..).zip(values.iter()) {
            let candidate = TropicalMinPlus::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k);
        }

        assert_eq!(acc.0, 3.0);
        assert_eq!(idx, 1); // Index where min occurred
    }

    #[test]
    fn test_argmin_pos_infinity() {
        let a = TropicalMinPlus::<f64>::tropical_zero(); // +inf
        let b = TropicalMinPlus::new(100.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 100.0);
        assert_eq!(idx, 1); // 100 < +inf
    }

    /// The additive identity absorbs under `⊗`, from either side.
    #[test]
    fn test_absorbing_zero() {
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::<f64>::tropical_zero();

        // a ⊗ 0 = a + (+inf) = +inf
        let right = a.tropical_mul(zero);
        assert!(right.0.is_infinite() && right.0 > 0.0);

        // 0 ⊗ a = (+inf) + a = +inf
        let left = zero.tropical_mul(a);
        assert!(left.0.is_infinite() && left.0 > 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // Add operator (min)
        assert_eq!((a + b).0, 3.0);
        assert_eq!((b + a).0, 3.0);

        // Mul operator (add)
        assert_eq!((a * b).0, 8.0);
        assert_eq!((b * a).0, 8.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMinPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 > 0.0); // +inf
        assert_eq!(d, TropicalMinPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMinPlus::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMinPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMinPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMinPlus::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMinPlus::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMinPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMinPlus::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    // The explicit `.clone()` on a `Copy` type is the behaviour under test.
    #[allow(clippy::clone_on_copy)]
    fn test_clone_copy() {
        let a = TropicalMinPlus::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMinPlus::new(5.0f64);
        let a2 = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    /// Same operations with `f32` as the scalar type.
    #[test]
    fn test_f32() {
        let a = TropicalMinPlus::new(3.0f32);
        let b = TropicalMinPlus::new(5.0f32);

        assert!((a.tropical_add(b).0 - 3.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 8.0).abs() < 1e-6);
    }
}