tropical_gemm/types/
max_plus.rs1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
18#[repr(transparent)]
19pub struct TropicalMaxPlus<T: TropicalScalar>(pub T);
20
21impl<T: TropicalScalar> TropicalMaxPlus<T> {
22 #[inline(always)]
24 pub fn new(value: T) -> Self {
25 Self(value)
26 }
27}
28
29impl<T: TropicalScalar> TropicalSemiring for TropicalMaxPlus<T> {
30 type Scalar = T;
31
32 #[inline(always)]
33 fn tropical_zero() -> Self {
34 Self(T::neg_infinity())
35 }
36
37 #[inline(always)]
38 fn tropical_one() -> Self {
39 Self(T::scalar_zero())
40 }
41
42 #[inline(always)]
43 fn tropical_add(self, rhs: Self) -> Self {
44 Self(self.0.scalar_max(rhs.0))
45 }
46
47 #[inline(always)]
48 fn tropical_mul(self, rhs: Self) -> Self {
49 Self(self.0.scalar_add(rhs.0))
50 }
51
52 #[inline(always)]
53 fn value(&self) -> T {
54 self.0
55 }
56
57 #[inline(always)]
58 fn from_scalar(s: T) -> Self {
59 Self(s)
60 }
61}
62
63impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxPlus<T> {
64 type Index = u32;
65
66 #[inline(always)]
67 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
68 if self.0 >= rhs.0 {
69 (self, self_idx)
70 } else {
71 (rhs, rhs_idx)
72 }
73 }
74}
75
76impl<T: TropicalScalar> SimdTropical for TropicalMaxPlus<T> {
77 const SIMD_AVAILABLE: bool = true;
78 const SIMD_WIDTH: usize = 8; }
80
81impl<T: TropicalScalar> Add for TropicalMaxPlus<T> {
82 type Output = Self;
83
84 #[inline(always)]
85 fn add(self, rhs: Self) -> Self::Output {
86 self.tropical_add(rhs)
87 }
88}
89
90impl<T: TropicalScalar> Mul for TropicalMaxPlus<T> {
91 type Output = Self;
92
93 #[inline(always)]
94 fn mul(self, rhs: Self) -> Self::Output {
95 self.tropical_mul(rhs)
96 }
97}
98
99impl<T: TropicalScalar> Default for TropicalMaxPlus<T> {
100 #[inline(always)]
101 fn default() -> Self {
102 Self::tropical_zero()
103 }
104}
105
106impl<T: TropicalScalar> fmt::Debug for TropicalMaxPlus<T> {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(f, "TropicalMaxPlus({})", self.0)
109 }
110}
111
112impl<T: TropicalScalar> fmt::Display for TropicalMaxPlus<T> {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "{}", self.0)
115 }
116}
117
118impl<T: TropicalScalar> From<T> for TropicalMaxPlus<T> {
119 #[inline(always)]
120 fn from(value: T) -> Self {
121 Self(value)
122 }
123}
124
#[cfg(test)]
mod tests {
    // Unit tests for `TropicalMaxPlus`: semiring identities and laws, argmax
    // tie-breaking, operator overloads, and the std trait impls.
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = TropicalMaxPlus::new(5.0f64);
        let zero = TropicalMaxPlus::tropical_zero();
        let one = TropicalMaxPlus::tropical_one();

        // max(a, zero) == a
        assert_eq!(a.tropical_add(zero), a);
        // a + one == a
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_operations() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        assert_eq!(a.tropical_add(b).0, 5.0);
        assert_eq!(a.tropical_mul(b).0, 8.0);
    }

    #[test]
    fn test_mul_distributes_over_add() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);
        let c = TropicalMaxPlus::new(2.0f64);

        // a + max(b, c) == max(a + b, a + c)
        let lhs = a.tropical_mul(b.tropical_add(c));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(c));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs.0, 8.0);
    }

    #[test]
    fn test_argmax() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_left_wins() {
        let a = TropicalMaxPlus::new(7.0f64);
        let b = TropicalMaxPlus::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(10, b, 20);
        assert_eq!(result.0, 7.0);
        assert_eq!(idx, 10);
    }

    #[test]
    fn test_argmax_equal_values() {
        // Ties keep the left operand and its index.
        let a = TropicalMaxPlus::new(5.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_chain() {
        // Fold candidates left-to-right, tracking the index of the running max.
        let mut acc = TropicalMaxPlus::tropical_zero();
        let mut idx = 0u32;

        let values = [3.0, 7.0, 2.0, 5.0];
        for (k, &val) in values.iter().enumerate() {
            let candidate = TropicalMaxPlus::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        assert_eq!(acc.0, 7.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_neg_infinity() {
        let a = TropicalMaxPlus::tropical_zero();
        let b = TropicalMaxPlus::new(-100.0f64);

        let (result, idx) = a.tropical_add_argmax(0, b, 1);
        assert_eq!(result.0, -100.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_absorbing_zero() {
        let a = TropicalMaxPlus::new(5.0f64);
        let zero = TropicalMaxPlus::tropical_zero();

        // a + (-inf) == -inf: the tropical zero absorbs under multiplication.
        let result = a.tropical_mul(zero);
        assert!(result.0.is_infinite() && result.0 < 0.0);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMaxPlus::new(3.0f64);
        let b = TropicalMaxPlus::new(5.0f64);

        // `+` is max, in either order.
        assert_eq!((a + b).0, 5.0);
        assert_eq!((b + a).0, 5.0);

        // `*` is ordinary addition, in either order.
        assert_eq!((a * b).0, 8.0);
        assert_eq!((b * a).0, 8.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMaxPlus::<f64>::default();
        assert!(d.0.is_infinite() && d.0 < 0.0);
        assert_eq!(d, TropicalMaxPlus::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMaxPlus::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMaxPlus(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMaxPlus<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMaxPlus::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMaxPlus::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMaxPlus::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMaxPlus::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMaxPlus::<f64>::SIMD_WIDTH, 8);
    }

    // The explicit `.clone()` on a `Copy` type is deliberate: it exercises the
    // derived `Clone` impl, which `clippy::clone_on_copy` would otherwise flag.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn test_clone_copy() {
        let a = TropicalMaxPlus::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMaxPlus::new(5.0f64);
        let a2 = TropicalMaxPlus::new(5.0f64);
        let b = TropicalMaxPlus::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = TropicalMaxPlus::new(3.0f32);
        let b = TropicalMaxPlus::new(5.0f32);

        assert!((a.tropical_add(b).0 - 5.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 8.0).abs() < 1e-6);
    }
}