// tropical_gemm/types/scalar.rs

1use std::fmt::{Debug, Display};
2
/// A plain numeric value that can be stored inside a tropical number.
///
/// The trait exposes the *ordinary* identities and operations of the underlying
/// type, plus the two extremal values and the `max`/`min` selections from which
/// the tropical semirings are assembled. Types without a genuine infinity may
/// supply a sentinel instead (the integer implementations use `MAX`/`MIN`).
pub trait TropicalScalar:
    'static + Sized + Copy + Clone + Send + Sync + PartialOrd + Debug + Display
{
    /// Additive identity of the underlying type's ordinary arithmetic.
    fn scalar_zero() -> Self;

    /// Multiplicative identity of the underlying type's ordinary arithmetic.
    fn scalar_one() -> Self;

    /// Ordinary addition of `self` and `rhs`.
    fn scalar_add(self, rhs: Self) -> Self;

    /// Ordinary multiplication of `self` and `rhs`.
    fn scalar_mul(self, rhs: Self) -> Self;

    /// The value used as `+∞`; serves as the zero element of MinPlus.
    fn pos_infinity() -> Self;

    /// The value used as `-∞`; serves as the zero element of MaxPlus.
    fn neg_infinity() -> Self;

    /// The larger of `self` and `rhs`.
    fn scalar_max(self, rhs: Self) -> Self;

    /// The smaller of `self` and `rhs`.
    fn scalar_min(self, rhs: Self) -> Self;
}
31
/// Implements [`TropicalScalar`] for IEEE-754 floating-point types.
///
/// The tropical sentinels are the real infinities. `scalar_max`/`scalar_min`
/// return `self` only when the `>=`/`<=` comparison holds; because every
/// comparison involving NaN is false, a NaN in either operand yields `rhs`.
macro_rules! impl_tropical_scalar_float {
    // Internal arm: emit the impl for a single float type.
    (@impl $t:ty) => {
        impl TropicalScalar for $t {
            #[inline(always)]
            fn scalar_zero() -> Self {
                0.0
            }

            #[inline(always)]
            fn scalar_one() -> Self {
                1.0
            }

            #[inline(always)]
            fn scalar_add(self, rhs: Self) -> Self {
                self + rhs
            }

            #[inline(always)]
            fn scalar_mul(self, rhs: Self) -> Self {
                self * rhs
            }

            #[inline(always)]
            fn pos_infinity() -> Self {
                Self::INFINITY
            }

            #[inline(always)]
            fn neg_infinity() -> Self {
                Self::NEG_INFINITY
            }

            /// Returns `self` when `self >= rhs`, otherwise `rhs` (including any NaN case).
            #[inline(always)]
            fn scalar_max(self, rhs: Self) -> Self {
                if self >= rhs { self } else { rhs }
            }

            /// Returns `self` when `self <= rhs`, otherwise `rhs` (including any NaN case).
            #[inline(always)]
            fn scalar_min(self, rhs: Self) -> Self {
                if self <= rhs { self } else { rhs }
            }
        }
    };
    // Public arm: a comma-separated list of float types.
    ($($t:ty),*) => {
        $( impl_tropical_scalar_float!(@impl $t); )*
    };
}
79
/// Implements [`TropicalScalar`] for primitive integer types.
///
/// Integers have no infinity, so the type's `MAX`/`MIN` stand in for `+inf`/`-inf`.
/// Addition and multiplication saturate at those bounds instead of overflowing.
/// With plain `+`/`*`, combining a value with a sentinel (e.g. `x + MAX` for
/// `x > 0`) panics when overflow checks are enabled (the debug default) and wraps
/// around otherwise, silently turning "infinity" into a finite value at the
/// opposite end of the range. Saturation keeps such results pinned at the bound.
/// It does not make the sentinels fully absorbing, though: for signed types
/// `MAX + (-1)` is still `MAX - 1`.
macro_rules! impl_tropical_scalar_int {
    ($($t:ty),*) => {
        $(
            impl TropicalScalar for $t {
                #[inline(always)]
                fn scalar_zero() -> Self {
                    0
                }

                #[inline(always)]
                fn scalar_one() -> Self {
                    1
                }

                /// Saturating addition: clamps to `MIN`/`MAX` instead of overflowing.
                #[inline(always)]
                fn scalar_add(self, rhs: Self) -> Self {
                    self.saturating_add(rhs)
                }

                /// Saturating multiplication: clamps to `MIN`/`MAX` instead of overflowing.
                #[inline(always)]
                fn scalar_mul(self, rhs: Self) -> Self {
                    self.saturating_mul(rhs)
                }

                /// `MAX` is used as the `+inf` sentinel.
                #[inline(always)]
                fn pos_infinity() -> Self {
                    <$t>::MAX
                }

                /// `MIN` is used as the `-inf` sentinel (`0` for unsigned types).
                #[inline(always)]
                fn neg_infinity() -> Self {
                    <$t>::MIN
                }

                #[inline(always)]
                fn scalar_max(self, rhs: Self) -> Self {
                    if self >= rhs { self } else { rhs }
                }

                #[inline(always)]
                fn scalar_min(self, rhs: Self) -> Self {
                    if self <= rhs { self } else { rhs }
                }
            }
        )*
    };
}
127
128impl_tropical_scalar_float!(f32, f64);
129impl_tropical_scalar_int!(i32, i64, i8, i16, u8, u16, u32, u64);
130
131impl TropicalScalar for bool {
132    #[inline(always)]
133    fn scalar_zero() -> Self {
134        false
135    }
136
137    #[inline(always)]
138    fn scalar_one() -> Self {
139        true
140    }
141
142    #[inline(always)]
143    fn scalar_add(self, rhs: Self) -> Self {
144        self || rhs
145    }
146
147    #[inline(always)]
148    fn scalar_mul(self, rhs: Self) -> Self {
149        self && rhs
150    }
151
152    #[inline(always)]
153    fn pos_infinity() -> Self {
154        true
155    }
156
157    #[inline(always)]
158    fn neg_infinity() -> Self {
159        false
160    }
161
162    #[inline(always)]
163    fn scalar_max(self, rhs: Self) -> Self {
164        self || rhs
165    }
166
167    #[inline(always)]
168    fn scalar_min(self, rhs: Self) -> Self {
169        self && rhs
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` results.
    const F32_EPS: f32 = 1e-6;

    /// True when `actual` is within [`F32_EPS`] of `expected`.
    fn approx_f32(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < F32_EPS
    }

    /// Generates a named test that exercises every `TropicalScalar` method on
    /// one integer type with the operands 3 and 5.
    macro_rules! int_scalar_test {
        ($name:ident, $t:ty) => {
            #[test]
            fn $name() {
                let three: $t = 3;
                assert_eq!(<$t>::scalar_zero(), 0);
                assert_eq!(<$t>::scalar_one(), 1);
                assert_eq!(three.scalar_add(5), 8);
                assert_eq!(three.scalar_mul(5), 15);
                assert_eq!(<$t>::pos_infinity(), <$t>::MAX);
                assert_eq!(<$t>::neg_infinity(), <$t>::MIN);
                assert_eq!(three.scalar_max(5), 5);
                assert_eq!(three.scalar_min(5), 3);
            }
        };
    }

    #[test]
    fn test_f64_scalar() {
        let (a, b) = (3.0f64, 5.0f64);
        assert_eq!(f64::scalar_zero(), 0.0);
        assert_eq!(f64::scalar_one(), 1.0);
        assert_eq!(a.scalar_add(b), 8.0);
        assert_eq!(a.scalar_mul(b), 15.0);
        let (pos, neg) = (f64::pos_infinity(), f64::neg_infinity());
        assert!(pos.is_infinite() && pos > 0.0);
        assert!(neg.is_infinite() && neg < 0.0);
        assert_eq!(a.scalar_max(b), 5.0);
        assert_eq!(a.scalar_min(b), 3.0);
    }

    #[test]
    fn test_f32_scalar() {
        let (a, b) = (3.0f32, 5.0f32);
        assert_eq!(f32::scalar_zero(), 0.0);
        assert_eq!(f32::scalar_one(), 1.0);
        assert!(approx_f32(a.scalar_add(b), 8.0));
        assert!(approx_f32(a.scalar_mul(b), 15.0));
        let (pos, neg) = (f32::pos_infinity(), f32::neg_infinity());
        assert!(pos.is_infinite() && pos > 0.0);
        assert!(neg.is_infinite() && neg < 0.0);
        assert!(approx_f32(a.scalar_max(b), 5.0));
        assert!(approx_f32(a.scalar_min(b), 3.0));
    }

    int_scalar_test!(test_i32_scalar, i32);
    int_scalar_test!(test_i64_scalar, i64);
    int_scalar_test!(test_i8_scalar, i8);
    int_scalar_test!(test_i16_scalar, i16);
    int_scalar_test!(test_u8_scalar, u8);
    int_scalar_test!(test_u16_scalar, u16);
    int_scalar_test!(test_u32_scalar, u32);
    int_scalar_test!(test_u64_scalar, u64);

    #[test]
    fn test_bool_scalar() {
        assert!(!bool::scalar_zero());
        assert!(bool::scalar_one());
        // scalar_add is OR and scalar_mul is AND over the full truth table.
        for (l, r) in [(false, false), (false, true), (true, false), (true, true)] {
            assert_eq!(l.scalar_add(r), l || r);
            assert_eq!(l.scalar_mul(r), l && r);
        }
        // pos_infinity is true, neg_infinity is false.
        assert!(bool::pos_infinity());
        assert!(!bool::neg_infinity());
        // scalar_max behaves as OR.
        assert!(true.scalar_max(false));
        assert!(!false.scalar_max(false));
        // scalar_min behaves as AND.
        assert!(!true.scalar_min(false));
        assert!(true.scalar_min(true));
    }

    #[test]
    fn test_float_edge_cases() {
        // Equal operands: max and min must both return that value.
        assert_eq!(5.0f64.scalar_max(5.0), 5.0);
        assert_eq!(5.0f64.scalar_min(5.0), 5.0);
        assert_eq!(5.0f32.scalar_max(5.0), 5.0);
        assert_eq!(5.0f32.scalar_min(5.0), 5.0);
    }

    #[test]
    fn test_int_edge_cases() {
        // Equal operands.
        assert_eq!(5i32.scalar_max(5), 5);
        assert_eq!(5i32.scalar_min(5), 5);
        // Negative operands.
        let (a, b) = (-3i32, -5i32);
        assert_eq!(a.scalar_max(b), -3);
        assert_eq!(a.scalar_min(b), -5);
    }
}
340}