tropical_gemm/types/min_plus.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// TropicalMinPlus semiring: (ℝ ∪ {+∞}, min, +)
7///
8/// - Addition (⊕) = min
9/// - Multiplication (⊗) = +
10/// - Zero = +∞
11/// - One = 0
12///
13/// This is used for:
14/// - Shortest path algorithms (Dijkstra, Floyd-Warshall)
15/// - Dynamic programming with minimum cost
16#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMinPlus<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMinPlus<T> {
21    /// Create a new TropicalMinPlus value.
22    #[inline(always)]
23    pub fn new(value: T) -> Self {
24        Self(value)
25    }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMinPlus<T> {
29    type Scalar = T;
30
31    #[inline(always)]
32    fn tropical_zero() -> Self {
33        Self(T::pos_infinity())
34    }
35
36    #[inline(always)]
37    fn tropical_one() -> Self {
38        Self(T::scalar_zero())
39    }
40
41    #[inline(always)]
42    fn tropical_add(self, rhs: Self) -> Self {
43        Self(self.0.scalar_min(rhs.0))
44    }
45
46    #[inline(always)]
47    fn tropical_mul(self, rhs: Self) -> Self {
48        Self(self.0.scalar_add(rhs.0))
49    }
50
51    #[inline(always)]
52    fn value(&self) -> T {
53        self.0
54    }
55
56    #[inline(always)]
57    fn from_scalar(s: T) -> Self {
58        Self(s)
59    }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMinPlus<T> {
63    type Index = u32;
64
65    #[inline(always)]
66    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67        // For min, we track argmin
68        if self.0 <= rhs.0 {
69            (self, self_idx)
70        } else {
71            (rhs, rhs_idx)
72        }
73    }
74}
75
76impl<T: TropicalScalar> SimdTropical for TropicalMinPlus<T> {
77    const SIMD_AVAILABLE: bool = true;
78    const SIMD_WIDTH: usize = 8;
79}
80
81impl<T: TropicalScalar> Add for TropicalMinPlus<T> {
82    type Output = Self;
83
84    #[inline(always)]
85    fn add(self, rhs: Self) -> Self::Output {
86        self.tropical_add(rhs)
87    }
88}
89
90impl<T: TropicalScalar> Mul for TropicalMinPlus<T> {
91    type Output = Self;
92
93    #[inline(always)]
94    fn mul(self, rhs: Self) -> Self::Output {
95        self.tropical_mul(rhs)
96    }
97}
98
99impl<T: TropicalScalar> Default for TropicalMinPlus<T> {
100    #[inline(always)]
101    fn default() -> Self {
102        Self::tropical_zero()
103    }
104}
105
106impl<T: TropicalScalar> fmt::Debug for TropicalMinPlus<T> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        write!(f, "TropicalMinPlus({})", self.0)
109    }
110}
111
112impl<T: TropicalScalar> fmt::Display for TropicalMinPlus<T> {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "{}", self.0)
115    }
116}
117
118impl<T: TropicalScalar> From<T> for TropicalMinPlus<T> {
119    #[inline(always)]
120    fn from(value: T) -> Self {
121        Self(value)
122    }
123}
124
#[cfg(test)]
mod tests {
    // Unit tests for the min-plus semiring wrapper: identities, operations,
    // argmin tracking, trait impls and the semiring laws.
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::tropical_zero();
        let one = TropicalMinPlus::tropical_one();

        // a ⊕ 0 = a
        assert_eq!(a.tropical_add(zero), a);
        // a ⊗ 1 = a
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_operations() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // min(3, 5) = 3
        assert_eq!(a.tropical_add(b).0, 3.0);
        // 3 + 5 = 8
        assert_eq!(a.tropical_mul(b).0, 8.0);
    }

    #[test]
    fn test_semiring_laws() {
        // Small integer-valued floats keep every sum exact.
        let a = TropicalMinPlus::new(2.0f64);
        let b = TropicalMinPlus::new(7.0f64);
        let c = TropicalMinPlus::new(4.0f64);
        let zero = TropicalMinPlus::tropical_zero();
        let one = TropicalMinPlus::tropical_one();

        // ⊕ and ⊗ are commutative.
        assert_eq!(a.tropical_add(b), b.tropical_add(a));
        assert_eq!(a.tropical_mul(b), b.tropical_mul(a));

        // ⊕ and ⊗ are associative.
        assert_eq!(
            a.tropical_add(b).tropical_add(c),
            a.tropical_add(b.tropical_add(c))
        );
        assert_eq!(
            a.tropical_mul(b).tropical_mul(c),
            a.tropical_mul(b.tropical_mul(c))
        );

        // Identities also hold with the identity on the left.
        assert_eq!(zero.tropical_add(a), a);
        assert_eq!(one.tropical_mul(a), a);

        // ⊗ distributes over ⊕: a + min(b, c) == min(a + b, a + c).
        assert_eq!(
            a.tropical_mul(b.tropical_add(c)),
            a.tropical_mul(b).tropical_add(a.tropical_mul(c))
        );
    }

    #[test]
    fn test_shortest_path_scenario() {
        // Simulating: path cost a=10, path cost b=5, combine = min(10,5) = 5
        let a = TropicalMinPlus::new(10.0f64);
        let b = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.tropical_add(b).0, 5.0);

        // Extending a path: cost=5, edge=3, total = 5+3 = 8
        let path = TropicalMinPlus::new(5.0f64);
        let edge = TropicalMinPlus::new(3.0f64);
        assert_eq!(path.tropical_mul(edge).0, 8.0);
    }

    #[test]
    fn test_argmin_right_wins() {
        // For MinPlus, argmax actually tracks argmin
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 3.0);
        assert_eq!(idx, 1); // Right has smaller value
    }

    #[test]
    fn test_argmin_left_wins() {
        let a = TropicalMinPlus::new(2.0f64);
        let b = TropicalMinPlus::new(7.0f64);

        let (result, idx) = a.tropical_add_argmax(10, b, 20);
        assert_eq!(result.0, 2.0);
        assert_eq!(idx, 10); // Left has smaller value
    }

    #[test]
    fn test_argmin_equal_values() {
        // When values are equal, left (self) wins (<= comparison)
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1); // Equal, so left (self) wins
    }

    #[test]
    fn test_argmin_chain() {
        // Simulate accumulating through k iterations - find minimum
        let mut acc = TropicalMinPlus::tropical_zero(); // +inf
        let mut idx = 0u32;

        let values = [8.0, 3.0, 9.0, 5.0]; // Min is at index 1
        for (k, &val) in values.iter().enumerate() {
            let candidate = TropicalMinPlus::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        assert_eq!(acc.0, 3.0);
        assert_eq!(idx, 1); // Index where min occurred
    }

    #[test]
    fn test_argmin_pos_infinity() {
        let a = TropicalMinPlus::tropical_zero(); // +inf
        let b = TropicalMinPlus::new(100.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 100.0);
        assert_eq!(idx, 1); // 100 < +inf
    }

    #[test]
    fn test_absorbing_zero() {
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::tropical_zero();

        // a ⊗ 0 = a + (+inf) = +inf
        let result = a.tropical_mul(zero);
        assert!(result.0.is_infinite() && result.0 > 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // Add operator (min)
        assert_eq!((a + b).0, 3.0);
        assert_eq!((b + a).0, 3.0);

        // Mul operator (add)
        assert_eq!((a * b).0, 8.0);
        assert_eq!((b * a).0, 8.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMinPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 > 0.0); // +inf
        assert_eq!(d, TropicalMinPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMinPlus::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMinPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMinPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMinPlus::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMinPlus::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMinPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMinPlus::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    // Calling `.clone()` on a `Copy` value is intentional: this test exercises
    // the derived `Clone` impl separately from `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn test_clone_copy() {
        let a = TropicalMinPlus::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMinPlus::new(5.0f64);
        let a2 = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = TropicalMinPlus::new(3.0f32);
        let b = TropicalMinPlus::new(5.0f32);

        assert!((a.tropical_add(b).0 - 3.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 8.0).abs() < 1e-6);
    }
}