tropical_gemm/types/traits.rs

1use super::scalar::TropicalScalar;
2use std::fmt::Debug;
3
4/// Core trait for tropical semiring operations.
5///
6/// A semiring (S, ⊕, ⊗) satisfies:
7/// - (S, ⊕) is a commutative monoid with identity `tropical_zero`
8/// - (S, ⊗) is a monoid with identity `tropical_one`
9/// - ⊗ distributes over ⊕
10/// - `tropical_zero` is absorbing: a ⊗ 0 = 0 ⊗ a = 0
11pub trait TropicalSemiring: Copy + Clone + Send + Sync + Debug + PartialEq + 'static {
12    /// The underlying scalar type.
13    type Scalar: TropicalScalar;
14
15    /// Returns the additive identity (zero element for ⊕).
16    fn tropical_zero() -> Self;
17
18    /// Returns the multiplicative identity (one element for ⊗).
19    fn tropical_one() -> Self;
20
21    /// Tropical addition (⊕).
22    fn tropical_add(self, rhs: Self) -> Self;
23
24    /// Tropical multiplication (⊗).
25    fn tropical_mul(self, rhs: Self) -> Self;
26
27    /// Get the underlying scalar value.
28    fn value(&self) -> Self::Scalar;
29
30    /// Create from a scalar value.
31    fn from_scalar(s: Self::Scalar) -> Self;
32}
33
34/// Extension trait for tropical types that support argmax tracking.
35///
36/// This is used for backpropagation: during matrix multiplication,
37/// we track which k index produced the optimal value for each C[i,j].
38pub trait TropicalWithArgmax: TropicalSemiring {
39    /// The index type used for argmax tracking.
40    type Index: Copy + Default + Debug + Send + Sync + 'static;
41
42    /// Tropical addition with argmax tracking.
43    ///
44    /// Returns the result of `tropical_add` along with the index
45    /// corresponding to which operand "won" (produced the result).
46    fn tropical_add_argmax(
47        self,
48        self_idx: Self::Index,
49        rhs: Self,
50        rhs_idx: Self::Index,
51    ) -> (Self, Self::Index);
52}
53
54/// Marker trait for tropical types that support SIMD acceleration.
55pub trait SimdTropical: TropicalSemiring {
56    /// Whether SIMD operations are available for this type.
57    const SIMD_AVAILABLE: bool;
58
59    /// The SIMD width in elements.
60    const SIMD_WIDTH: usize;
61}