tropical_gemm/types/max_mul.rs

1use super::scalar::TropicalScalar;
2use super::traits::{SimdTropical, TropicalSemiring, TropicalWithArgmax};
3use std::fmt;
4use std::ops::{Add, Mul};
5
6/// TropicalMaxMul semiring: (ℝ⁺, max, ×)
7///
8/// - Addition (⊕) = max
9/// - Multiplication (⊗) = ×
10/// - Zero = 0
11/// - One = 1
12///
13/// This is used for:
14/// - Probability computations (non-log space)
15/// - Fuzzy logic with product t-norm
16#[derive(Copy, Clone, PartialEq)]
17#[repr(transparent)]
18pub struct TropicalMaxMul<T: TropicalScalar>(pub T);
19
20impl<T: TropicalScalar> TropicalMaxMul<T> {
21    /// Create a new TropicalMaxMul value.
22    #[inline(always)]
23    pub fn new(value: T) -> Self {
24        Self(value)
25    }
26}
27
28impl<T: TropicalScalar> TropicalSemiring for TropicalMaxMul<T> {
29    type Scalar = T;
30
31    #[inline(always)]
32    fn tropical_zero() -> Self {
33        Self(T::scalar_zero())
34    }
35
36    #[inline(always)]
37    fn tropical_one() -> Self {
38        Self(T::scalar_one())
39    }
40
41    #[inline(always)]
42    fn tropical_add(self, rhs: Self) -> Self {
43        Self(self.0.scalar_max(rhs.0))
44    }
45
46    #[inline(always)]
47    fn tropical_mul(self, rhs: Self) -> Self {
48        Self(self.0.scalar_mul(rhs.0))
49    }
50
51    #[inline(always)]
52    fn value(&self) -> T {
53        self.0
54    }
55
56    #[inline(always)]
57    fn from_scalar(s: T) -> Self {
58        Self(s)
59    }
60}
61
62impl<T: TropicalScalar> TropicalWithArgmax for TropicalMaxMul<T> {
63    type Index = u32;
64
65    #[inline(always)]
66    fn tropical_add_argmax(self, self_idx: u32, rhs: Self, rhs_idx: u32) -> (Self, u32) {
67        if self.0 >= rhs.0 {
68            (self, self_idx)
69        } else {
70            (rhs, rhs_idx)
71        }
72    }
73}
74
75impl<T: TropicalScalar> SimdTropical for TropicalMaxMul<T> {
76    const SIMD_AVAILABLE: bool = true;
77    const SIMD_WIDTH: usize = 8;
78}
79
80impl<T: TropicalScalar> Add for TropicalMaxMul<T> {
81    type Output = Self;
82
83    #[inline(always)]
84    fn add(self, rhs: Self) -> Self::Output {
85        self.tropical_add(rhs)
86    }
87}
88
89impl<T: TropicalScalar> Mul for TropicalMaxMul<T> {
90    type Output = Self;
91
92    #[inline(always)]
93    fn mul(self, rhs: Self) -> Self::Output {
94        self.tropical_mul(rhs)
95    }
96}
97
98impl<T: TropicalScalar> Default for TropicalMaxMul<T> {
99    #[inline(always)]
100    fn default() -> Self {
101        Self::tropical_zero()
102    }
103}
104
105impl<T: TropicalScalar> fmt::Debug for TropicalMaxMul<T> {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        write!(f, "TropicalMaxMul({})", self.0)
108    }
109}
110
111impl<T: TropicalScalar> fmt::Display for TropicalMaxMul<T> {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        write!(f, "{}", self.0)
114    }
115}
116
117impl<T: TropicalScalar> From<T> for TropicalMaxMul<T> {
118    #[inline(always)]
119    fn from(value: T) -> Self {
120        Self(value)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an `f64` so individual tests stay short.
    fn mm(x: f64) -> TropicalMaxMul<f64> {
        TropicalMaxMul::new(x)
    }

    /// Runs one argmax-aware addition with fixed indices 1 (lhs) and 2 (rhs).
    /// Returns the winning raw value together with its index.
    fn argmax_pair(lhs: f64, rhs: f64) -> (f64, u32) {
        let (winner, idx) = mm(lhs).tropical_add_argmax(1, mm(rhs), 2);
        (winner.0, idx)
    }

    #[test]
    fn test_semiring_identity() {
        let five = mm(5.0);
        let zero = TropicalMaxMul::<f64>::tropical_zero();
        let one = TropicalMaxMul::<f64>::tropical_one();

        // a ⊕ 0 = a
        assert_eq!(five.tropical_add(zero), five);
        // a ⊗ 1 = a
        assert_eq!(five.tropical_mul(one), five);
    }

    #[test]
    fn test_operations() {
        let (three, five) = (mm(3.0), mm(5.0));

        assert_eq!(three.tropical_add(five).0, 5.0, "max(3, 5) = 5");
        assert_eq!(three.tropical_mul(five).0, 15.0, "3 * 5 = 15");
    }

    #[test]
    fn test_absorbing_zero() {
        let zero = TropicalMaxMul::<f64>::tropical_zero();

        // a ⊗ 0 = 0
        assert_eq!(mm(5.0).tropical_mul(zero), zero);
    }

    #[test]
    fn test_operator_overloads() {
        let (three, five) = (mm(3.0), mm(5.0));

        // Check both operand orders: `+` is max, `*` is the ordinary product.
        for (lhs, rhs) in [(three, five), (five, three)] {
            assert_eq!((lhs + rhs).0, 5.0);
            assert_eq!((lhs * rhs).0, 15.0);
        }
    }

    #[test]
    fn test_default() {
        let fallback: TropicalMaxMul<f64> = Default::default();

        assert_eq!(fallback.0, 0.0); // Zero is 0 for MaxMul
        assert_eq!(fallback, TropicalMaxMul::tropical_zero());
    }

    #[test]
    fn test_display_debug() {
        let five = mm(5.0);

        assert_eq!(five.to_string(), "5");
        assert_eq!(format!("{:?}", five), "TropicalMaxMul(5)");
    }

    #[test]
    fn test_from() {
        let via_into: TropicalMaxMul<f64> = 5.0.into();
        assert_eq!(via_into.0, 5.0);

        let via_from = TropicalMaxMul::<f64>::from(3.0);
        assert_eq!(via_from.0, 3.0);
    }

    #[test]
    fn test_value_and_from_scalar() {
        assert_eq!(mm(5.0).value(), 5.0);
        assert_eq!(TropicalMaxMul::<f64>::from_scalar(3.0).value(), 3.0);
    }

    #[test]
    fn test_argmax_self_wins() {
        assert_eq!(argmax_pair(7.0, 3.0), (7.0, 1));
    }

    #[test]
    fn test_argmax_rhs_wins() {
        assert_eq!(argmax_pair(3.0, 7.0), (7.0, 2));
    }

    #[test]
    fn test_argmax_equal_self_wins() {
        // Ties keep the left operand because the comparison is `>=`.
        assert_eq!(argmax_pair(5.0, 5.0), (5.0, 1));
    }

    #[test]
    fn test_argmax_chain() {
        let values = [3.0f64, 7.0, 2.0, 5.0]; // Max at index 1
        let start = (TropicalMaxMul::<f64>::tropical_zero(), 0u32);

        let (best, best_idx) = values
            .iter()
            .enumerate()
            .fold(start, |(acc, idx), (k, &val)| {
                acc.tropical_add_argmax(idx, mm(val), k as u32)
            });

        assert_eq!(best.0, 7.0);
        assert_eq!(best_idx, 1);
    }

    #[test]
    fn test_simd_tropical() {
        assert!(<TropicalMaxMul<f64> as SimdTropical>::SIMD_AVAILABLE);
        assert_eq!(<TropicalMaxMul<f64> as SimdTropical>::SIMD_WIDTH, 8);
    }

    #[test]
    fn test_clone_copy() {
        let original = mm(5.0);
        let copied = original;
        let cloned = original.clone();

        assert_eq!(original, copied);
        assert_eq!(original, cloned);
    }

    #[test]
    fn test_eq() {
        let (first, second, other) = (mm(5.0), mm(5.0), mm(3.0));

        assert_eq!(first, second);
        assert_ne!(first, other);
    }

    #[test]
    fn test_f32() {
        let three = TropicalMaxMul::new(3.0f32);
        let five = TropicalMaxMul::new(5.0f32);
        let approx = |got: f32, want: f32| (got - want).abs() < 1e-6;

        assert!(approx(three.tropical_add(five).0, 5.0));
        assert!(approx(three.tropical_mul(five).0, 15.0));
    }

    #[test]
    fn test_fuzzy_logic_example() {
        // Fuzzy OR is max; fuzzy AND is the product t-norm.
        let (high, medium, low) = (mm(0.9), mm(0.5), mm(0.2));

        // OR(high, low)
        assert_eq!(high.tropical_add(low).0, 0.9);
        // AND(high, medium)
        assert!((high.tropical_mul(medium).0 - 0.45).abs() < 1e-10);
    }
}