Skip to main content

tropical_gemm/types/
scalar.rs

1use std::fmt::{Debug, Display};
2
/// Numeric building block that sits underneath every tropical number type.
///
/// An implementor supplies its ordinary ("standard") arithmetic, the two
/// infinities that serve as tropical zeros, and max/min selection. Two optional
/// hooks report whether a value has to be treated as a tropical zero even
/// though it may not be bit-identical to the canonical one.
pub trait TropicalScalar:
    Copy + Clone + Send + Sync + Debug + Display + PartialOrd + 'static + Sized
{
    /// Identity element of standard addition.
    fn scalar_zero() -> Self;

    /// Identity element of standard multiplication.
    fn scalar_one() -> Self;

    /// Sum of `self` and `rhs` under standard arithmetic.
    fn scalar_add(self, rhs: Self) -> Self;

    /// Product of `self` and `rhs` under standard arithmetic.
    fn scalar_mul(self, rhs: Self) -> Self;

    /// The value standing for `+∞` (the MinPlus zero).
    fn pos_infinity() -> Self;

    /// The value standing for `-∞` (the MaxPlus zero).
    fn neg_infinity() -> Self;

    /// The larger of `self` and `rhs`.
    fn scalar_max(self, rhs: Self) -> Self;

    /// The smaller of `self` and `rhs`.
    fn scalar_min(self, rhs: Self) -> Self;

    /// Reports whether `self` lies in the negative "infinity territory" of an
    /// in-band integer sentinel, i.e. at or beyond `neg_infinity() / 2`.
    ///
    /// Integer tropical zeros are finite sentinels combined with a guard-free
    /// `+`, so a cell that received no contribution can end up near the
    /// sentinel without being equal to it. The sentinel itself also counts.
    /// GEMM write-back uses this hook to canonicalize the argmax index of such
    /// cells.
    ///
    /// The default answer is `false`: float infinities are exact and cannot
    /// drift, and the narrow / unsigned integers do not take part in the
    /// headroom-sentinel scheme. Only `i32` and `i64` override the hook, so for
    /// every other monomorphization the canonicalization branch is a constant
    /// `false` that the optimizer can remove.
    #[inline(always)]
    fn is_drifted_neg_zero(self) -> bool {
        false
    }

    /// Mirror image of [`TropicalScalar::is_drifted_neg_zero`] for `+∞`:
    /// reports whether `self` lies at or beyond `pos_infinity() / 2`. Defaults
    /// to `false` for the same reasons.
    #[inline(always)]
    fn is_drifted_pos_zero(self) -> bool {
        false
    }
}
54
/// Implements `TropicalScalar` for each listed floating-point type.
///
/// Floats carry exact infinities, so `INFINITY` / `NEG_INFINITY` are used
/// directly as the tropical zeros and the drift hooks keep their defaults.
macro_rules! impl_tropical_scalar_float {
    ($($float:ty),*) => {
        $(
            impl TropicalScalar for $float {
                /// `0.0`, the identity of standard addition.
                #[inline(always)]
                fn scalar_zero() -> Self {
                    0.0
                }

                /// `1.0`, the identity of standard multiplication.
                #[inline(always)]
                fn scalar_one() -> Self {
                    1.0
                }

                /// Plain floating-point `+`.
                #[inline(always)]
                fn scalar_add(self, rhs: Self) -> Self {
                    self + rhs
                }

                /// Plain floating-point `*`.
                #[inline(always)]
                fn scalar_mul(self, rhs: Self) -> Self {
                    self * rhs
                }

                /// The type's own `INFINITY` constant.
                #[inline(always)]
                fn pos_infinity() -> Self {
                    <$float>::INFINITY
                }

                /// The type's own `NEG_INFINITY` constant.
                #[inline(always)]
                fn neg_infinity() -> Self {
                    <$float>::NEG_INFINITY
                }

                /// Keeps `self` when it is at least `rhs`, otherwise `rhs`.
                /// A failed comparison (either operand NaN) yields `rhs`.
                #[inline(always)]
                fn scalar_max(self, rhs: Self) -> Self {
                    if rhs <= self {
                        self
                    } else {
                        rhs
                    }
                }

                /// Keeps `self` when it is at most `rhs`, otherwise `rhs`.
                /// A failed comparison (either operand NaN) yields `rhs`.
                #[inline(always)]
                fn scalar_min(self, rhs: Self) -> Self {
                    if rhs >= self {
                        self
                    } else {
                        rhs
                    }
                }
            }
        )*
    };
}
102
/// Expands to the six `TropicalScalar` methods that are identical for every
/// integer type: the two identities, `+`, `*`, and max/min selection.
///
/// How a type represents its tropical zero (`MIN`/`MAX` for the narrow types,
/// a headroom sentinel for `i32`/`i64`) is left to the invoking `impl`. The
/// bare literals `0` and `1` take the type `Self` at each expansion site.
macro_rules! tropical_int_common {
    () => {
        /// Integer `0`.
        #[inline(always)]
        fn scalar_zero() -> Self {
            0
        }

        /// Integer `1`.
        #[inline(always)]
        fn scalar_one() -> Self {
            1
        }

        /// Built-in integer `+`, with no overflow guard of its own.
        #[inline(always)]
        fn scalar_add(self, rhs: Self) -> Self {
            self + rhs
        }

        /// Built-in integer `*`, with no overflow guard of its own.
        #[inline(always)]
        fn scalar_mul(self, rhs: Self) -> Self {
            self * rhs
        }

        /// `rhs` when it is strictly larger, otherwise `self`.
        #[inline(always)]
        fn scalar_max(self, rhs: Self) -> Self {
            if rhs > self {
                rhs
            } else {
                self
            }
        }

        /// `rhs` when it is strictly smaller, otherwise `self`.
        #[inline(always)]
        fn scalar_min(self, rhs: Self) -> Self {
            if rhs < self {
                rhs
            } else {
                self
            }
        }
    };
}
148
/// Implements `TropicalScalar` for narrow / unsigned integer types, taking the
/// type's own `MAX` and `MIN` as `+∞` and `-∞`.
///
/// These types stay outside the headroom-sentinel scheme and therefore keep the
/// default (`false`) drift hooks. The flip side is that a guard-free
/// `zero ⊗ zero` (`MIN + MIN`) can overflow. They are not wired into
/// `KernelDispatch`, so only the low-level core kernels can reach them; the
/// public `tropical_matmul*` API cannot.
macro_rules! impl_tropical_scalar_int {
    ($($int:ty),*) => {
        $(
            impl TropicalScalar for $int {
                tropical_int_common!();

                /// Largest representable value of the type.
                #[inline(always)]
                fn pos_infinity() -> Self {
                    <$int>::MAX
                }

                /// Smallest representable value of the type.
                #[inline(always)]
                fn neg_infinity() -> Self {
                    <$int>::MIN
                }
            }
        )*
    };
}
173
/// Implements `TropicalScalar` for the wide signed integers (`i32`/`i64`) from
/// an explicit `(negative, positive)` pair of *headroom* sentinels.
///
/// The sentinels are deliberately far from `MIN`/`MAX`: a guard-free `+` then
/// cannot overflow on `zero ⊗ zero` (`±S + ±S` is still representable), the
/// sentinels do not collide with realistic data, and a tropical zero that has
/// drifted away from the exact sentinel is still recognizable by the
/// `|value| >= |S|/2` test. The values mirror the CUDA backend
/// (`NEG_INF_I32 = -1e9`, `NEG_INF_I64 = -(1 << 60)`).
macro_rules! impl_tropical_scalar_int_wide {
    ($($int:ty => ($neg_sentinel:expr, $pos_sentinel:expr)),* $(,)?) => {
        $(
            impl TropicalScalar for $int {
                tropical_int_common!();

                /// The finite sentinel standing in for `+∞`.
                #[inline(always)]
                fn pos_infinity() -> Self {
                    $pos_sentinel
                }

                /// The finite sentinel standing in for `-∞`.
                #[inline(always)]
                fn neg_infinity() -> Self {
                    $neg_sentinel
                }

                /// `true` at or below half of the negative sentinel.
                #[inline(always)]
                fn is_drifted_neg_zero(self) -> bool {
                    let territory_edge: Self = $neg_sentinel / 2;
                    self <= territory_edge
                }

                /// `true` at or above half of the positive sentinel.
                #[inline(always)]
                fn is_drifted_pos_zero(self) -> bool {
                    let territory_edge: Self = $pos_sentinel / 2;
                    self >= territory_edge
                }
            }
        )*
    };
}
208
209impl_tropical_scalar_float!(f32, f64);
210impl_tropical_scalar_int!(i8, i16, u8, u16, u32, u64);
211impl_tropical_scalar_int_wide!(
212    i32 => (-1_000_000_000, 1_000_000_000),
213    i64 => (-(1i64 << 60), 1i64 << 60),
214);
215
216impl TropicalScalar for bool {
217    #[inline(always)]
218    fn scalar_zero() -> Self {
219        false
220    }
221
222    #[inline(always)]
223    fn scalar_one() -> Self {
224        true
225    }
226
227    #[inline(always)]
228    fn scalar_add(self, rhs: Self) -> Self {
229        self || rhs
230    }
231
232    #[inline(always)]
233    fn scalar_mul(self, rhs: Self) -> Self {
234        self && rhs
235    }
236
237    #[inline(always)]
238    fn pos_infinity() -> Self {
239        true
240    }
241
242    #[inline(always)]
243    fn neg_infinity() -> Self {
244        false
245    }
246
247    #[inline(always)]
248    fn scalar_max(self, rhs: Self) -> Self {
249        self || rhs
250    }
251
252    #[inline(always)]
253    fn scalar_min(self, rhs: Self) -> Self {
254        self && rhs
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Emits one `#[test]` that checks the full `TropicalScalar` surface of an
    /// integer type: identities, `3 + 5`, `3 * 5`, both infinities and
    /// max/min of `3` and `5`.
    macro_rules! integer_scalar_test {
        ($test_name:ident, $int:ty, pos = $pos:expr, neg = $neg:expr) => {
            #[test]
            fn $test_name() {
                let (three, five): ($int, $int) = (3, 5);
                assert_eq!(<$int>::scalar_zero(), 0);
                assert_eq!(<$int>::scalar_one(), 1);
                assert_eq!(three.scalar_add(five), 8);
                assert_eq!(three.scalar_mul(five), 15);
                assert_eq!(<$int>::pos_infinity(), $pos);
                assert_eq!(<$int>::neg_infinity(), $neg);
                assert_eq!(three.scalar_max(five), 5);
                assert_eq!(three.scalar_min(five), 3);
            }
        };
    }

    #[test]
    fn test_f64_scalar() {
        let (lhs, rhs) = (3.0f64, 5.0f64);
        assert_eq!(f64::scalar_zero(), 0.0);
        assert_eq!(f64::scalar_one(), 1.0);
        assert_eq!(lhs.scalar_add(rhs), 8.0);
        assert_eq!(lhs.scalar_mul(rhs), 15.0);
        // "Infinite and positive" is exactly +∞; likewise for the negative side.
        assert_eq!(f64::pos_infinity(), f64::INFINITY);
        assert_eq!(f64::neg_infinity(), f64::NEG_INFINITY);
        assert_eq!(lhs.scalar_max(rhs), 5.0);
        assert_eq!(lhs.scalar_min(rhs), 3.0);
    }

    #[test]
    fn test_f32_scalar() {
        // Arithmetic results are compared with a small absolute tolerance.
        let near = |got: f32, want: f32| (got - want).abs() < 1e-6;
        let (lhs, rhs) = (3.0f32, 5.0f32);
        assert_eq!(f32::scalar_zero(), 0.0);
        assert_eq!(f32::scalar_one(), 1.0);
        assert!(near(lhs.scalar_add(rhs), 8.0));
        assert!(near(lhs.scalar_mul(rhs), 15.0));
        assert_eq!(f32::pos_infinity(), f32::INFINITY);
        assert_eq!(f32::neg_infinity(), f32::NEG_INFINITY);
        assert!(near(lhs.scalar_max(rhs), 5.0));
        assert!(near(lhs.scalar_min(rhs), 3.0));
    }

    // Wide signed integers use a headroom sentinel (not MIN/MAX): guard-free +
    // stays in range, drift is detectable by threshold, and the values match
    // the CUDA backend.
    integer_scalar_test!(
        test_i32_scalar,
        i32,
        pos = 1_000_000_000,
        neg = -1_000_000_000
    );
    integer_scalar_test!(test_i64_scalar, i64, pos = 1i64 << 60, neg = -(1i64 << 60));

    // Narrow / unsigned integers use the type's own extremes.
    integer_scalar_test!(test_i8_scalar, i8, pos = i8::MAX, neg = i8::MIN);
    integer_scalar_test!(test_i16_scalar, i16, pos = i16::MAX, neg = i16::MIN);
    integer_scalar_test!(test_u8_scalar, u8, pos = u8::MAX, neg = u8::MIN);
    integer_scalar_test!(test_u16_scalar, u16, pos = u16::MAX, neg = u16::MIN);
    integer_scalar_test!(test_u32_scalar, u32, pos = u32::MAX, neg = u32::MIN);
    integer_scalar_test!(test_u64_scalar, u64, pos = u64::MAX, neg = u64::MIN);

    #[test]
    fn test_bool_scalar() {
        assert!(!bool::scalar_zero());
        assert!(bool::scalar_one());

        // scalar_add is OR, scalar_mul is AND: (lhs, rhs, lhs OR rhs, lhs AND rhs).
        let truth_table = [
            (true, false, true, false),
            (false, true, true, false),
            (false, false, false, false),
            (true, true, true, true),
        ];
        for &(lhs, rhs, or, and) in truth_table.iter() {
            assert_eq!(lhs.scalar_add(rhs), or);
            assert_eq!(lhs.scalar_mul(rhs), and);
        }

        // pos_infinity is true, neg_infinity is false.
        assert!(bool::pos_infinity());
        assert!(!bool::neg_infinity());

        // scalar_max is OR.
        for &(lhs, rhs, want) in [(true, false, true), (false, false, false)].iter() {
            assert_eq!(lhs.scalar_max(rhs), want);
        }
        // scalar_min is AND.
        for &(lhs, rhs, want) in [(true, false, false), (true, true, true)].iter() {
            assert_eq!(lhs.scalar_min(rhs), want);
        }
    }

    #[test]
    fn test_float_edge_cases() {
        // A tie must hand back the shared value from both selectors.
        let tie64 = 5.0f64;
        assert_eq!(tie64.scalar_max(tie64), 5.0);
        assert_eq!(tie64.scalar_min(tie64), 5.0);
        let tie32 = 5.0f32;
        assert_eq!(tie32.scalar_max(tie32), 5.0);
        assert_eq!(tie32.scalar_min(tie32), 5.0);
    }

    #[test]
    fn test_int_edge_cases() {
        // A tie must hand back the shared value from both selectors.
        let tie = 5i32;
        assert_eq!(tie.scalar_max(tie), 5);
        assert_eq!(tie.scalar_min(tie), 5);
        // Ordering of negative operands.
        let (minus_three, minus_five) = (-3i32, -5i32);
        assert_eq!(minus_three.scalar_max(minus_five), -3);
        assert_eq!(minus_three.scalar_min(minus_five), -5);
    }

    #[test]
    fn test_drifted_zero_detection() {
        // The sentinel itself and anything in "infinity territory" (past S/2)
        // counts as a drifted tropical zero; realistic data and the
        // multiplicative one do not. Threshold = ±5e8 for i32, ±2^59 for i64.
        let neg32 = i32::neg_infinity();
        let pos32 = i32::pos_infinity();
        for &value in [neg32, neg32 + 1000].iter() {
            // `neg32 + 1000` has drifted but is still inside the territory.
            assert!(value.is_drifted_neg_zero());
        }
        for &value in [pos32, pos32 - 1000].iter() {
            assert!(value.is_drifted_pos_zero());
        }

        // Realistic values are never mistaken for the zero.
        for &value in [0i32, 123_456, -123_456].iter() {
            assert!(!value.is_drifted_neg_zero());
        }
        assert!(!0i32.is_drifted_pos_zero());

        assert!(i64::neg_infinity().is_drifted_neg_zero());
        assert!(i64::pos_infinity().is_drifted_pos_zero());
        for &value in [0i64, 1_000_000_000_000].iter() {
            assert!(!value.is_drifted_neg_zero());
        }

        // Floats never "drift": their ±∞ are exact, so the default hook applies.
        assert!(!f64::neg_infinity().is_drifted_neg_zero());
        assert!(!f64::pos_infinity().is_drifted_pos_zero());

        // Narrow / unsigned ints keep MIN/MAX and the default `false` hook.
        assert!(!i8::neg_infinity().is_drifted_neg_zero());
        assert!(!u32::pos_infinity().is_drifted_pos_zero());
    }
}