tropical_gemm/types/
max_mul.rs1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMaxMul<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMaxMul<T> {
21 #[inline(always)]
23 pub fn new(value: T) -> Self {
24 Self(value)
25 }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMaxMul<T> {
29 type Scalar = T;
30
31 #[inline(always)]
32 fn tropical_zero() -> Self {
33 Self(T::scalar_zero())
34 }
35
36 #[inline(always)]
37 fn tropical_one() -> Self {
38 Self(T::scalar_one())
39 }
40
41 #[inline(always)]
42 fn tropical_add(self, rhs: Self) -> Self {
43 Self(self.0.scalar_max(rhs.0))
44 }
45
46 #[inline(always)]
47 fn tropical_mul(self, rhs: Self) -> Self {
48 Self(self.0.scalar_mul(rhs.0))
49 }
50
51 #[inline(always)]
52 fn value(&self) -> T {
53 self.0
54 }
55
56 #[inline(always)]
57 fn from_scalar(s: T) -> Self {
58 Self(s)
59 }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxMul<T> {
63 type Index = u32;
64
65 #[inline(always)]
66 fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67 if self.0 >= rhs.0 {
68 (self, self_idx)
69 } else {
70 (rhs, rhs_idx)
71 }
72 }
73}
74
75impl<T: TropicalScalar> SimdTropical for TropicalMaxMul<T> {
76 const SIMD_AVAILABLE: bool = true;
77 const SIMD_WIDTH: usize = 8;
78}
79
80impl<T: TropicalScalar> Add for TropicalMaxMul<T> {
81 type Output = Self;
82
83 #[inline(always)]
84 fn add(self, rhs: Self) -> Self::Output {
85 self.tropical_add(rhs)
86 }
87}
88
89impl<T: TropicalScalar> Mul for TropicalMaxMul<T> {
90 type Output = Self;
91
92 #[inline(always)]
93 fn mul(self, rhs: Self) -> Self::Output {
94 self.tropical_mul(rhs)
95 }
96}
97
98impl<T: TropicalScalar> Default for TropicalMaxMul<T> {
99 #[inline(always)]
100 fn default() -> Self {
101 Self::tropical_zero()
102 }
103}
104
105impl<T: TropicalScalar> fmt::Debug for TropicalMaxMul<T> {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 write!(f, "TropicalMaxMul({})", self.0)
108 }
109}
110
111impl<T: TropicalScalar> fmt::Display for TropicalMaxMul<T> {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 write!(f, "{}", self.0)
114 }
115}
116
117impl<T: TropicalScalar> From<T> for TropicalMaxMul<T> {
118 #[inline(always)]
119 fn from(value: T) -> Self {
120 Self(value)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_semiring_identity() {
        let a = TropicalMaxMul::new(5.0f64);
        let zero = TropicalMaxMul::tropical_zero();
        let one = TropicalMaxMul::tropical_one();

        // Additive identity: max(a, 0) == a for non-negative a.
        assert_eq!(a.tropical_add(zero), a);
        // Multiplicative identity: a * 1 == a.
        assert_eq!(a.tropical_mul(one), a);
    }

    #[test]
    fn test_operations() {
        let a = TropicalMaxMul::new(3.0f64);
        let b = TropicalMaxMul::new(5.0f64);

        // Tropical addition is max.
        assert_eq!(a.tropical_add(b).0, 5.0);
        // Tropical multiplication is the ordinary product.
        assert_eq!(a.tropical_mul(b).0, 15.0);
    }

    #[test]
    fn test_absorbing_zero() {
        let a = TropicalMaxMul::new(5.0f64);
        let zero = TropicalMaxMul::tropical_zero();

        // The tropical zero annihilates under multiplication: a * 0 == 0.
        assert_eq!(a.tropical_mul(zero), zero);
    }

    #[test]
    fn test_commutativity() {
        let a = TropicalMaxMul::new(3.0f64);
        let b = TropicalMaxMul::new(5.0f64);

        assert_eq!(a.tropical_add(b), b.tropical_add(a));
        assert_eq!(a.tropical_mul(b), b.tropical_mul(a));
    }

    #[test]
    fn test_associativity() {
        let a = TropicalMaxMul::new(2.0f64);
        let b = TropicalMaxMul::new(3.0f64);
        let c = TropicalMaxMul::new(5.0f64);

        assert_eq!(
            a.tropical_add(b).tropical_add(c),
            a.tropical_add(b.tropical_add(c))
        );
        assert_eq!(
            a.tropical_mul(b).tropical_mul(c),
            a.tropical_mul(b.tropical_mul(c))
        );
    }

    #[test]
    fn test_distributivity() {
        // a * max(b, c) == max(a * b, a * c) for non-negative a.
        let a = TropicalMaxMul::new(2.0f64);
        let b = TropicalMaxMul::new(3.0f64);
        let c = TropicalMaxMul::new(5.0f64);

        let lhs = a.tropical_mul(b.tropical_add(c));
        let rhs = a.tropical_mul(b).tropical_add(a.tropical_mul(c));
        assert_eq!(lhs, rhs);
        assert_eq!(lhs.0, 10.0);
    }

    #[test]
    fn test_additive_idempotence() {
        // max(a, a) == a
        let a = TropicalMaxMul::new(4.0f64);
        assert_eq!(a.tropical_add(a), a);
    }

    #[test]
    fn test_operator_overloads() {
        let a = TropicalMaxMul::new(3.0f64);
        let b = TropicalMaxMul::new(5.0f64);

        // `+` is max.
        assert_eq!((a + b).0, 5.0);
        assert_eq!((b + a).0, 5.0);

        // `*` is the ordinary product.
        assert_eq!((a * b).0, 15.0);
        assert_eq!((b * a).0, 15.0);
    }

    #[test]
    fn test_default() {
        let d = TropicalMaxMul::<f64>::default();
        assert_eq!(d.0, 0.0);
        assert_eq!(d, TropicalMaxMul::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let a = TropicalMaxMul::new(5.0f64);

        assert_eq!(format!("{}", a), "5");
        assert_eq!(format!("{:?}", a), "TropicalMaxMul(5)");
    }

    #[test]
    fn test_from() {
        let a: TropicalMaxMul<f64> = 5.0.into();
        assert_eq!(a.0, 5.0);

        let b = TropicalMaxMul::<f64>::from(3.0);
        assert_eq!(b.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        let a = TropicalMaxMul::new(5.0f64);
        assert_eq!(a.value(), 5.0);

        let b = TropicalMaxMul::<f64>::from_scalar(3.0);
        assert_eq!(b.value(), 3.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        let a = TropicalMaxMul::new(7.0f64);
        let b = TropicalMaxMul::new(3.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 7.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_rhs_wins() {
        let a = TropicalMaxMul::new(3.0f64);
        let b = TropicalMaxMul::new(7.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 7.0);
        assert_eq!(idx, 2);
    }

    #[test]
    fn test_argmax_equal_self_wins() {
        // On a tie the left operand (and its index) is kept.
        let a = TropicalMaxMul::new(5.0f64);
        let b = TropicalMaxMul::new(5.0f64);

        let (result, idx) = a.tropical_add_argmax(1, b, 2);
        assert_eq!(result.0, 5.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_argmax_chain() {
        let mut acc = TropicalMaxMul::tropical_zero();
        let mut idx = 0u32;

        let values = [3.0f64, 7.0, 2.0, 5.0];
        for (k, &val) in values.iter().enumerate() {
            let candidate = TropicalMaxMul::new(val);
            (acc, idx) = acc.tropical_add_argmax(idx, candidate, k as u32);
        }

        assert_eq!(acc.0, 7.0);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(TropicalMaxMul::<f64>::SIMD_AVAILABLE);
        assert_eq!(TropicalMaxMul::<f64>::SIMD_WIDTH, 8);
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercises the derived `Clone` impl
    fn test_clone_copy() {
        let a = TropicalMaxMul::new(5.0f64);
        let a_copy = a;
        let a_clone = a.clone();

        assert_eq!(a, a_copy);
        assert_eq!(a, a_clone);
    }

    #[test]
    fn test_eq() {
        let a1 = TropicalMaxMul::new(5.0f64);
        let a2 = TropicalMaxMul::new(5.0f64);
        let b = TropicalMaxMul::new(3.0f64);

        assert_eq!(a1, a2);
        assert_ne!(a1, b);
    }

    #[test]
    fn test_f32() {
        let a = TropicalMaxMul::new(3.0f32);
        let b = TropicalMaxMul::new(5.0f32);

        assert!((a.tropical_add(b).0 - 5.0).abs() < 1e-6);
        assert!((a.tropical_mul(b).0 - 15.0).abs() < 1e-6);
    }

    #[test]
    fn test_fuzzy_logic_example() {
        // Fuzzy-membership degrees in [0, 1]: OR is max, AND is product.
        let high = TropicalMaxMul::new(0.9f64);
        let medium = TropicalMaxMul::new(0.5f64);
        let low = TropicalMaxMul::new(0.2f64);

        let or_result = high.tropical_add(low);
        assert_eq!(or_result.0, 0.9);

        let and_result = high.tropical_mul(medium);
        assert!((and_result.0 - 0.45).abs() < 1e-10);
    }
}