tropical_gemm/types/traits.rs
1use super::scalar::TropicalScalar;
2use std::fmt::Debug;
3
4/// Core trait for tropical semiring operations.
5///
6/// A semiring (S, ⊕, ⊗) satisfies:
7/// - (S, ⊕) is a commutative monoid with identity `tropical_zero`
8/// - (S, ⊗) is a monoid with identity `tropical_one`
9/// - ⊗ distributes over ⊕
10/// - `tropical_zero` is absorbing: a ⊗ 0 = 0 ⊗ a = 0
11pub trait TropicalSemiring: Copy + Clone + Send + Sync + Debug + PartialEq + 'static {
12 /// The underlying scalar type.
13 type Scalar: TropicalScalar;
14
15 /// Returns the additive identity (zero element for ⊕).
16 fn tropical_zero() -> Self;
17
18 /// Returns the multiplicative identity (one element for ⊗).
19 fn tropical_one() -> Self;
20
21 /// Tropical addition (⊕).
22 fn tropical_add(self, rhs: Self) -> Self;
23
24 /// Tropical multiplication (⊗).
25 fn tropical_mul(self, rhs: Self) -> Self;
26
27 /// Get the underlying scalar value.
28 fn value(&self) -> Self::Scalar;
29
30 /// Create from a scalar value.
31 fn from_scalar(s: Self::Scalar) -> Self;
32}
33
34/// Extension trait for tropical types that support argmax tracking.
35///
36/// This is used for backpropagation: during matrix multiplication,
37/// we track which k index produced the optimal value for each C[i,j].
38pub trait TropicalWithArgmax: TropicalSemiring {
39 /// The index type used for argmax tracking.
40 type Index: Copy + Default + Debug + Send + Sync + 'static;
41
42 /// Tropical addition with argmax tracking.
43 ///
44 /// Returns the result of `tropical_add` along with the index
45 /// corresponding to which operand "won" (produced the result).
46 fn tropical_add_argmax(
47 self,
48 self_idx: Self::Index,
49 rhs: Self,
50 rhs_idx: Self::Index,
51 ) -> (Self, Self::Index);
52
53 /// Whether this (output) value is a tropical-zero "no contribution" cell
54 /// whose argmax index should be canonicalized at GEMM write-back.
55 ///
56 /// Integer tropical zeros use a guard-free in-band sentinel, so a
57 /// no-contribution cell's value drifts and its accumulated argmax adopts a
58 /// spurious `k`. Returning `true` lets the kernel reset that index to the
59 /// deterministic seed (`0`) so the whole repo agrees on one value for such
60 /// cells (and the backward pass routes no gradient there once that seed
61 /// becomes `-1`, a later step).
62 ///
63 /// Default `false`: exact-infinity types (floats) don't drift — their zero
64 /// cells already keep the seed — so the branch folds away for them.
65 #[inline(always)]
66 fn is_no_contribution(&self) -> bool {
67 false
68 }
69}
70
71/// Marker trait for tropical types that support SIMD acceleration.
72pub trait SimdTropical: TropicalSemiring {
73 /// Whether SIMD operations are available for this type.
74 const SIMD_AVAILABLE: bool;
75
76 /// The SIMD width in elements.
77 const SIMD_WIDTH: usize;
78}