// tropical_gemm/types/traits.rs

use super::scalar::TropicalScalar;
use std::fmt::Debug;

4/// Core trait for tropical semiring operations.
5///
6/// A semiring (S, ⊕, ⊗) satisfies:
7/// - (S, ⊕) is a commutative monoid with identity `tropical_zero`
8/// - (S, ⊗) is a monoid with identity `tropical_one`
9/// - ⊗ distributes over ⊕
10/// - `tropical_zero` is absorbing: a ⊗ 0 = 0 ⊗ a = 0
11pub trait TropicalSemiring: Copy + Clone + Send + Sync + Debug + PartialEq + 'static {
12    /// The underlying scalar type.
13    type Scalar: TropicalScalar;
14
15    /// Returns the additive identity (zero element for ⊕).
16    fn tropical_zero() -> Self;
17
18    /// Returns the multiplicative identity (one element for ⊗).
19    fn tropical_one() -> Self;
20
21    /// Tropical addition (⊕).
22    fn tropical_add(self, rhs: Self) -> Self;
23
24    /// Tropical multiplication (⊗).
25    fn tropical_mul(self, rhs: Self) -> Self;
26
27    /// Get the underlying scalar value.
28    fn value(&self) -> Self::Scalar;
29
30    /// Create from a scalar value.
31    fn from_scalar(s: Self::Scalar) -> Self;
32}
33
34/// Extension trait for tropical types that support argmax tracking.
35///
36/// This is used for backpropagation: during matrix multiplication,
37/// we track which k index produced the optimal value for each C[i,j].
38pub trait TropicalWithArgmax: TropicalSemiring {
39    /// The index type used for argmax tracking.
40    type Index: Copy + Default + Debug + Send + Sync + 'static;
41
42    /// Tropical addition with argmax tracking.
43    ///
44    /// Returns the result of `tropical_add` along with the index
45    /// corresponding to which operand "won" (produced the result).
46    fn tropical_add_argmax(
47        self,
48        self_idx: Self::Index,
49        rhs: Self,
50        rhs_idx: Self::Index,
51    ) -> (Self, Self::Index);
52
53    /// Whether this (output) value is a tropical-zero "no contribution" cell
54    /// whose argmax index should be canonicalized at GEMM write-back.
55    ///
56    /// Integer tropical zeros use a guard-free in-band sentinel, so a
57    /// no-contribution cell's value drifts and its accumulated argmax adopts a
58    /// spurious `k`. Returning `true` lets the kernel reset that index to the
59    /// deterministic seed (`0`) so the whole repo agrees on one value for such
60    /// cells (and the backward pass routes no gradient there once that seed
61    /// becomes `-1`, a later step).
62    ///
63    /// Default `false`: exact-infinity types (floats) don't drift — their zero
64    /// cells already keep the seed — so the branch folds away for them.
65    #[inline(always)]
66    fn is_no_contribution(&self) -> bool {
67        false
68    }
69}
70
71/// Marker trait for tropical types that support SIMD acceleration.
72pub trait SimdTropical: TropicalSemiring {
73    /// Whether SIMD operations are available for this type.
74    const SIMD_AVAILABLE: bool;
75
76    /// The SIMD width in elements.
77    const SIMD_WIDTH: usize;
78}