tropical_gemm/types/traits.rs
1use super::scalar::TropicalScalar;
2use std::fmt::Debug;
3
4/// Core trait for tropical semiring operations.
5///
6/// A semiring (S, ⊕, ⊗) satisfies:
7/// - (S, ⊕) is a commutative monoid with identity `tropical_zero`
8/// - (S, ⊗) is a monoid with identity `tropical_one`
9/// - ⊗ distributes over ⊕
10/// - `tropical_zero` is absorbing: a ⊗ 0 = 0 ⊗ a = 0
11pub trait TropicalSemiring: Copy + Clone + Send + Sync + Debug + PartialEq + 'static {
12 /// The underlying scalar type.
13 type Scalar: TropicalScalar;
14
15 /// Returns the additive identity (zero element for ⊕).
16 fn tropical_zero() -> Self;
17
18 /// Returns the multiplicative identity (one element for ⊗).
19 fn tropical_one() -> Self;
20
21 /// Tropical addition (⊕).
22 fn tropical_add(self, rhs: Self) -> Self;
23
24 /// Tropical multiplication (⊗).
25 fn tropical_mul(self, rhs: Self) -> Self;
26
27 /// Get the underlying scalar value.
28 fn value(&self) -> Self::Scalar;
29
30 /// Create from a scalar value.
31 fn from_scalar(s: Self::Scalar) -> Self;
32}
33
34/// Extension trait for tropical types that support argmax tracking.
35///
36/// This is used for backpropagation: during matrix multiplication,
37/// we track which k index produced the optimal value for each C[i,j].
38pub trait TropicalWithArgmax: TropicalSemiring {
39 /// The index type used for argmax tracking.
40 type Index: Copy + Default + Debug + Send + Sync + 'static;
41
42 /// Tropical addition with argmax tracking.
43 ///
44 /// Returns the result of `tropical_add` along with the index
45 /// corresponding to which operand "won" (produced the result).
46 fn tropical_add_argmax(
47 self,
48 self_idx: Self::Index,
49 rhs: Self,
50 rhs_idx: Self::Index,
51 ) -> (Self, Self::Index);
52}
53
54/// Marker trait for tropical types that support SIMD acceleration.
55pub trait SimdTropical: TropicalSemiring {
56 /// Whether SIMD operations are available for this type.
57 const SIMD_AVAILABLE: bool;
58
59 /// The SIMD width in elements.
60 const SIMD_WIDTH: usize;
61}