tropical_gemm/types/
min_plus.rs1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMinPlus<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMinPlus<T> {
21 #[inline(always)]
23 pub fn new(value: T) -> Self {
24 Self(value)
25 }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMinPlus<T> {
29 type Scalar = T;
30
31 #[inline(always)]
32 fn tropical_zero() -> Self {
33 Self(T::pos_infinity())
34 }
35
36 #[inline(always)]
37 fn tropical_one() -> Self {
38 Self(T::scalar_zero())
39 }
40
41 #[inline(always)]
42 fn tropical_add(self, rhs: Self) -> Self {
43 Self(self.0.scalar_min(rhs.0))
44 }
45
46 #[inline(always)]
47 fn tropical_mul(self, rhs: Self) -> Self {
48 Self(self.0.scalar_add(rhs.0))
49 }
50
51 #[inline(always)]
52 fn value(&self) -> T {
53 self.0
54 }
55
56 #[inline(always)]
57 fn from_scalar(s: T) -> Self {
58 Self(s)
59 }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMinPlus<T> {
63 type Index = u32;
64
65 #[inline(always)]
66 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67 if self.0 <= rhs.0 {
69 (self, self_idx)
70 } else {
71 (rhs, rhs_idx)
72 }
73 }
74}
75
76impl<T: TropicalScalar> SimdTropical for TropicalMinPlus<T> {
77 const SIMD_AVAILABLE: bool = true;
78 const SIMD_WIDTH: usize = 8;
79}
80
81impl<T: TropicalScalar> Add for TropicalMinPlus<T> {
82 type Output = Self;
83
84 #[inline(always)]
85 fn add(self, rhs: Self) -> Self::Output {
86 self.tropical_add(rhs)
87 }
88}
89
90impl<T: TropicalScalar> Mul for TropicalMinPlus<T> {
91 type Output = Self;
92
93 #[inline(always)]
94 fn mul(self, rhs: Self) -> Self::Output {
95 self.tropical_mul(rhs)
96 }
97}
98
99impl<T: TropicalScalar> Default for TropicalMinPlus<T> {
100 #[inline(always)]
101 fn default() -> Self {
102 Self::tropical_zero()
103 }
104}
105
106impl<T: TropicalScalar> fmt::Debug for TropicalMinPlus<T> {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "TropicalMinPlus({})", self.0)
109 }
110}
111
112impl<T: TropicalScalar> fmt::Display for TropicalMinPlus<T> {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "{}", self.0)
115 }
116}
117
118impl<T: TropicalScalar> From<T> for TropicalMinPlus<T> {
119 #[inline(always)]
120 fn from(value: T) -> Self {
121 Self(value)
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::tropical_zero();
        let one = TropicalMinPlus::tropical_one();

        // min(a, +inf) == a
        assert_eq!(a.tropical_add(zero), a);
        // a + 0 == a
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_operations() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // min(3, 5) = 3
        assert_eq!(a.tropical_add(b).0, 3.0);
        // 3 + 5 = 8
        assert_eq!(a.tropical_mul(b).0, 8.0);
    }

    #[test]
    fn test_shortest_path_scenario() {
        // Two candidate path lengths: the shorter one is kept.
        let a = TropicalMinPlus::new(10.0f64);
        let b = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.tropical_add(b).0, 5.0);

        // Extending a path by an edge adds their lengths.
        let path = TropicalMinPlus::new(5.0f64);
        let edge = TropicalMinPlus::new(3.0f64);
        assert_eq!(path.tropical_mul(edge).0, 8.0);
    }

    #[test]
    fn test_add_idempotent() {
        // min(a, a) == a
        let a = TropicalMinPlus::new(4.0f64);
        assert_eq!(a.tropical_add(a), a);
    }

    #[test]
    fn test_argmin_right_wins() {
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 3.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_left_wins() {
        let a = TropicalMinPlus::new(2.0f64);
        let b = TropicalMinPlus::new(7.0f64);

        let (result, idx) = a.tropical_add_argmax(10, b, 20);
        assert_eq!(result.0, 2.0);
        assert_eq!(idx, 10);
    }

    #[test]
    fn test_argmin_equal_values() {
        // Ties keep the left operand's index.
        let a = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_both_tropical_zero_keeps_left() {
        // +inf <= +inf holds, so the tie rule keeps the left index.
        let zero = TropicalMinPlus::<f64>::tropical_zero();

        let (result, idx) = zero.tropical_add_argmax(4, zero, 9);
        assert_eq!(result, zero);
        assert_eq!(idx, 4);
    }

    #[test]
    fn test_argmin_chain() {
        // Fold a row left to right, starting from the additive identity.
        let mut acc = TropicalMinPlus::tropical_zero();
        let mut idx = 0u32;

        let values = [8.0, 3.0, 9.0, 5.0];
        for (k, &val) in values.iter().enumerate() {
            let candidate = TropicalMinPlus::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        assert_eq!(acc.0, 3.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_pos_infinity() {
        let a = TropicalMinPlus::tropical_zero();
        let b = TropicalMinPlus::new(100.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 100.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_absorbing_zero() {
        // a + inf == inf: the tropical zero absorbs under tropical multiplication.
        let a = TropicalMinPlus::new(5.0f64);
        let zero = TropicalMinPlus::tropical_zero();

        let result = a.tropical_mul(zero);
        assert!(result.0.is_infinite() && result.0 > 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMinPlus::new(3.0f64);
        let b = TropicalMinPlus::new(5.0f64);

        // `+` is min, in either order.
        assert_eq!((a + b).0, 3.0);
        assert_eq!((b + a).0, 3.0);

        // `*` is ordinary addition, in either order.
        assert_eq!((a * b).0, 8.0);
        assert_eq!((b * a).0, 8.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMinPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 > 0.0);
        assert_eq!(d, TropicalMinPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMinPlus::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMinPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMinPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMinPlus::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMinPlus::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMinPlus::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMinPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMinPlus::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    // Calling `.clone()` on a `Copy` type is deliberate here: this test
    // exercises the derived `Clone` impl, which clippy would otherwise flag.
    #[allow(clippy::clone_on_copy)]
    fn test_clone_copy() {
        let a = TropicalMinPlus::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMinPlus::new(5.0f64);
        let a2 = TropicalMinPlus::new(5.0f64);
        let b = TropicalMinPlus::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = TropicalMinPlus::new(3.0f32);
        let b = TropicalMinPlus::new(5.0f32);

        assert!((a.tropical_add(b).0 - 3.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 8.0).abs() < 1e-6);
    }
}