tropical_gemm/types/
max_plus.rs1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
18#[repr(transparent)]
19pub struct TropicalMaxPlus<T: TropicalScalar>(pub T);
20
21impl<T: TropicalScalar> TropicalMaxPlus<T> {
22 #[inline(always)]
24 pub fn new(value: T) -> Self {
25 Self(value)
26 }
27}
28
29impl<T: TropicalScalar> TropicalSemiring for TropicalMaxPlus<T> {
30 type Scalar = T;
31
32 #[inline(always)]
33 fn tropical_zero() -> Self {
34 Self(T::neg_infinity())
35 }
36
37 #[inline(always)]
38 fn tropical_one() -> Self {
39 Self(T::scalar_zero())
40 }
41
42 #[inline(always)]
43 fn tropical_add(self, rhs: Self) -> Self {
44 Self(self.0.scalar_max(rhs.0))
45 }
46
47 #[inline(always)]
48 fn tropical_mul(self, rhs: Self) -> Self {
49 Self(self.0.scalar_add(rhs.0))
50 }
51
52 #[inline(always)]
53 fn value(&self) -> T {
54 self.0
55 }
56
57 #[inline(always)]
58 fn from_scalar(s: T) -> Self {
59 Self(s)
60 }
61}
62
63impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxPlus<T> {
64 type Index = u32;
65
66 #[inline(always)]
67 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
68 if self.0 >= rhs.0 {
69 (self, self_idx)
70 } else {
71 (rhs, rhs_idx)
72 }
73 }
74
75 #[inline(always)]
76 fn is_no_contribution(&self) -> bool {
77 self.0.is_drifted_neg_zero()
78 }
79}
80
81impl<T: TropicalScalar> SimdTropical for TropicalMaxPlus<T> {
82 const SIMD_AVAILABLE: bool = true;
83 const SIMD_WIDTH: usize = 8; }
85
86impl<T: TropicalScalar> Add for TropicalMaxPlus<T> {
87 type Output = Self;
88
89 #[inline(always)]
90 fn add(self, rhs: Self) -> Self::Output {
91 self.tropical_add(rhs)
92 }
93}
94
95impl<T: TropicalScalar> Mul for TropicalMaxPlus<T> {
96 type Output = Self;
97
98 #[inline(always)]
99 fn mul(self, rhs: Self) -> Self::Output {
100 self.tropical_mul(rhs)
101 }
102}
103
104impl<T: TropicalScalar> Default for TropicalMaxPlus<T> {
105 #[inline(always)]
106 fn default() -> Self {
107 Self::tropical_zero()
108 }
109}
110
111impl<T: TropicalScalar> fmt::Debug for TropicalMaxPlus<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 write!(f, "TropicalMaxPlus({})", self.0)
114 }
115}
116
117impl<T: TropicalScalar> fmt::Display for TropicalMaxPlus<T> {
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 write!(f, "{}", self.0)
120 }
121}
122
123impl<T: TropicalScalar> From<T> for TropicalMaxPlus<T> {
124 #[inline(always)]
125 fn from(value: T) -> Self {
126 Self(value)
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `f64` instantiation used by most tests.
    type Mp = TropicalMaxPlus<f64>;

    #[test]
    fn test_semiring_identity() {
        let x = Mp::new(5.0);
        assert_eq!(x.tropical_add(Mp::tropical_zero()), x);
        assert_eq!(x.tropical_mul(Mp::tropical_one()), x);
    }

    #[test]
    fn test_operations() {
        let (lo, hi) = (Mp::new(3.0), Mp::new(5.0));
        assert_eq!(lo.tropical_add(hi).0, 5.0);
        assert_eq!(lo.tropical_mul(hi).0, 8.0);
    }

    #[test]
    fn test_argmax() {
        let (best, at) = Mp::new(3.0).tropical_add_argmax(0, Mp::new(5.0), 1);
        assert_eq!(best.0, 5.0);
        assert_eq!(at, 1);
    }

    #[test]
    fn test_argmax_left_wins() {
        let (best, at) = Mp::new(7.0).tropical_add_argmax(10, Mp::new(3.0), 20);
        assert_eq!(best.0, 7.0);
        assert_eq!(at, 10);
    }

    #[test]
    fn test_argmax_equal_values() {
        // Ties keep the left operand's index.
        let (best, at) = Mp::new(5.0).tropical_add_argmax(1, Mp::new(5.0), 2);
        assert_eq!(best.0, 5.0);
        assert_eq!(at, 1);
    }

    #[test]
    fn test_argmax_chain() {
        // Fold a running (value, index) accumulator across the candidates.
        let (best, at) = [3.0, 7.0, 2.0, 5.0]
            .iter()
            .enumerate()
            .fold((Mp::tropical_zero(), 0u32), |(acc, idx), (k, &val)| {
                acc.tropical_add_argmax(idx, Mp::new(val), k as u32)
            });
        assert_eq!(best.0, 7.0);
        assert_eq!(at, 1);
    }

    #[test]
    fn test_argmax_neg_infinity() {
        let (best, at) = Mp::tropical_zero().tropical_add_argmax(0, Mp::new(-100.0), 1);
        assert_eq!(best.0, -100.0);
        assert_eq!(at, 1);
    }

    #[test]
    fn test_absorbing_zero() {
        let absorbed = Mp::new(5.0).tropical_mul(Mp::tropical_zero()).0;
        assert!(absorbed.is_infinite() && absorbed < 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let (lo, hi) = (Mp::new(3.0), Mp::new(5.0));
        // Both operand orders must agree (commutativity of `+` and `*`).
        for (sum, product) in [(lo + hi, lo * hi), (hi + lo, hi * lo)] {
            assert_eq!(sum.0, 5.0);
            assert_eq!(product.0, 8.0);
        }
    }

    #[test]
    fn test_default() {
        let fallback = Mp::default();
        assert!(fallback.0.is_infinite() && fallback.0 < 0.0);
        assert_eq!(fallback, Mp::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let x = Mp::new(5.0);
        assert_eq!(x.to_string(), "5");
        assert_eq!(format!("{x:?}"), "TropicalMaxPlus(5)");
    }

    #[test]
    fn test_from() {
        let via_into: Mp = 5.0.into();
        assert_eq!(via_into.0, 5.0);
        assert_eq!(Mp::from(3.0).0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        assert_eq!(Mp::new(5.0).value(), 5.0);
        assert_eq!(Mp::from_scalar(3.0).value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(Mp::SIMD_AVAILABLE);
        assert_eq!(Mp::SIMD_WIDTH, 8);
    }

    #[test]
    fn test_clone_copy() {
        let original = Mp::new(5.0);
        let copied = original;
        #[allow(clippy::clone_on_copy)] // deliberately exercising the derived `Clone`
        let cloned = original.clone();
        assert_eq!(original, copied);
        assert_eq!(original, cloned);
    }

    #[test]
    fn test_eq() {
        assert_eq!(Mp::new(5.0), Mp::new(5.0));
        assert_ne!(Mp::new(5.0), Mp::new(3.0));
    }

    #[test]
    fn test_f32() {
        let (lo, hi) = (TropicalMaxPlus::new(3.0f32), TropicalMaxPlus::new(5.0f32));
        assert!((lo.tropical_add(hi).0 - 5.0).abs() < 1e-6);
        assert!((lo.tropical_mul(hi).0 - 8.0).abs() < 1e-6);
    }
}