// tropical_gemm/types/min_plus.rs

use std::fmt;
use std::ops::{Add, Mul};

use super::scalar::TropicalScalar;
use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
6#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMinPlus<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMinPlus<T> {
21 #[inline(always)]
23 pub fn new(value: T) -> Self {
24 Self(value)
25 }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMinPlus<T> {
29 type Scalar = T;
30
31 #[inline(always)]
32 fn tropical_zero() -> Self {
33 Self(T::pos_infinity())
34 }
35
36 #[inline(always)]
37 fn tropical_one() -> Self {
38 Self(T::scalar_zero())
39 }
40
41 #[inline(always)]
42 fn tropical_add(self, rhs: Self) -> Self {
43 Self(self.0.scalar_min(rhs.0))
44 }
45
46 #[inline(always)]
47 fn tropical_mul(self, rhs: Self) -> Self {
48 Self(self.0.scalar_add(rhs.0))
49 }
50
51 #[inline(always)]
52 fn value(&self) -> T {
53 self.0
54 }
55
56 #[inline(always)]
57 fn from_scalar(s: T) -> Self {
58 Self(s)
59 }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMinPlus<T> {
63 type Index = u32;
64
65 #[inline(always)]
66 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67 if self.0 <= rhs.0 {
69 (self, self_idx)
70 } else {
71 (rhs, rhs_idx)
72 }
73 }
74
75 #[inline(always)]
76 fn is_no_contribution(&self) -> bool {
77 self.0.is_drifted_pos_zero()
78 }
79}
80
81impl<T: TropicalScalar> SimdTropical for TropicalMinPlus<T> {
82 const SIMD_AVAILABLE: bool = true;
83 const SIMD_WIDTH: usize = 8;
84}
85
86impl<T: TropicalScalar> Add for TropicalMinPlus<T> {
87 type Output = Self;
88
89 #[inline(always)]
90 fn add(self, rhs: Self) -> Self::Output {
91 self.tropical_add(rhs)
92 }
93}
94
95impl<T: TropicalScalar> Mul for TropicalMinPlus<T> {
96 type Output = Self;
97
98 #[inline(always)]
99 fn mul(self, rhs: Self) -> Self::Output {
100 self.tropical_mul(rhs)
101 }
102}
103
104impl<T: TropicalScalar> Default for TropicalMinPlus<T> {
105 #[inline(always)]
106 fn default() -> Self {
107 Self::tropical_zero()
108 }
109}
110
111impl<T: TropicalScalar> fmt::Debug for TropicalMinPlus<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 write!(f, "TropicalMinPlus({})", self.0)
114 }
115}
116
117impl<T: TropicalScalar> fmt::Display for TropicalMinPlus<T> {
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 write!(f, "{}", self.0)
120 }
121}
122
123impl<T: TropicalScalar> From<T> for TropicalMinPlus<T> {
124 #[inline(always)]
125 fn from(value: T) -> Self {
126 Self(value)
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `f64`-backed min-plus value.
    fn mp(x: f64) -> TropicalMinPlus<f64> {
        TropicalMinPlus::new(x)
    }

    #[test]
    fn test_semiring_identity() {
        // +∞ is neutral for `min`; 0 is neutral for `+`.
        let a = mp(5.0);
        assert_eq!(a.tropical_add(TropicalMinPlus::tropical_zero()), a);
        assert_eq!(a.tropical_mul(TropicalMinPlus::tropical_one()), a);
    }

    #[test]
    fn test_operations() {
        let (a, b) = (mp(3.0), mp(5.0));
        assert_eq!(a.tropical_add(b).0, 3.0); // min(3, 5)
        assert_eq!(a.tropical_mul(b).0, 8.0); // 3 + 5
    }

    #[test]
    fn test_shortest_path_scenario() {
        // Of two alternative routes (cost 10 and cost 5) the cheaper survives.
        assert_eq!(mp(10.0).tropical_add(mp(5.0)).0, 5.0);
        // Extending a cost-5 path by a cost-3 edge gives total cost 8.
        assert_eq!(mp(5.0).tropical_mul(mp(3.0)).0, 8.0);
    }

    #[test]
    fn test_argmin_right_wins() {
        let (result, idx) = mp(5.0).tropical_add_argmax(0, mp(3.0), 1);
        assert_eq!(result.0, 3.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_left_wins() {
        let (result, idx) = mp(2.0).tropical_add_argmax(10, mp(7.0), 20);
        assert_eq!(result.0, 2.0);
        assert_eq!(idx, 10);
    }

    #[test]
    fn test_argmin_equal_values() {
        // Ties keep the left operand's index.
        let (result, idx) = mp(5.0).tropical_add_argmax(1, mp(5.0), 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_chain() {
        let start = (TropicalMinPlus::<f64>::tropical_zero(), 0u32);
        let (acc, idx) = [8.0, 3.0, 9.0, 5.0]
            .iter()
            .enumerate()
            .fold(start, |(acc, idx), (k, &val)| {
                acc.tropical_add_argmax(idx, mp(val), k as u32)
            });
        assert_eq!(acc.0, 3.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmin_pos_infinity() {
        let zero = TropicalMinPlus::<f64>::tropical_zero();
        let (result, idx) = zero.tropical_add_argmax(0, mp(100.0), 1);
        assert_eq!(result.0, 100.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_absorbing_zero() {
        // +∞ absorbs under tropical multiplication: 5 + ∞ = ∞.
        let result = mp(5.0).tropical_mul(TropicalMinPlus::tropical_zero());
        assert!(result.0.is_infinite() && result.0 > 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let (a, b) = (mp(3.0), mp(5.0));
        // Both operators are symmetric, so check each operand order.
        for (lhs, rhs) in [(a, b), (b, a)] {
            assert_eq!((lhs + rhs).0, 3.0);
            assert_eq!((lhs * rhs).0, 8.0);
        }
    }

    #[test]
    fn test_default() {
        let d = TropicalMinPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 > 0.0);
        assert_eq!(d, TropicalMinPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = mp(5.0);
        assert_eq!(a.to_string(), "5");
        assert_eq!(format!("{a:?}"), "TropicalMinPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMinPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);
        assert_eq!(TropicalMinPlus::<f64>::from(3.0).0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        assert_eq!(mp(5.0).value(), 5.0);
        assert_eq!(TropicalMinPlus::<f64>::from_scalar(3.0).value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMinPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMinPlus::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercises the derived `Clone`
    fn test_clone_copy() {
        let a = mp(5.0);
        let (copied, cloned) = (a, a.clone());
        assert_eq!(a, copied);
        assert_eq!(a, cloned);
    }

    #[test]
    fn test_eq() {
        assert_eq!(mp(5.0), mp(5.0));
        assert_ne!(mp(5.0), mp(3.0));
    }

    #[test]
    fn test_f32() {
        let a = TropicalMinPlus::new(3.0f32);
        let b = TropicalMinPlus::new(5.0f32);
        let close = |x: f32, y: f32| (x - y).abs() < 1e-6;
        assert!(close(a.tropical_add(b).0, 3.0));
        assert!(close(a.tropical_mul(b).0, 8.0));
    }
}