omeinsum/algebra/
semiring.rs

1//! Core algebraic traits for tensor operations.
2
3use super::Scalar;
4
5/// A semiring defines two binary operations (⊕, ⊗) with identities.
6///
7/// # Semiring Laws
8///
9/// For a semiring (S, ⊕, ⊗, 0, 1):
10/// - (S, ⊕, 0) is a commutative monoid
11/// - (S, ⊗, 1) is a monoid
12/// - ⊗ distributes over ⊕
13/// - 0 annihilates: a ⊗ 0 = 0 ⊗ a = 0
14///
15/// # Examples
16///
17/// | Semiring | ⊕ | ⊗ | 0 | 1 |
18/// |----------|---|---|---|---|
19/// | Standard | + | × | 0 | 1 |
20/// | MaxPlus  | max | + | -∞ | 0 |
21/// | MinPlus  | min | + | +∞ | 0 |
22/// | MaxMul   | max | × | 0 | 1 |
23pub trait Semiring: Copy + Clone + Send + Sync + 'static {
24    /// The underlying scalar type
25    type Scalar: Scalar;
26
27    /// Additive identity (zero element for ⊕)
28    fn zero() -> Self;
29
30    /// Multiplicative identity (one element for ⊗)
31    fn one() -> Self;
32
33    /// Addition operation (⊕)
34    fn add(self, rhs: Self) -> Self;
35
36    /// Multiplication operation (⊗)
37    fn mul(self, rhs: Self) -> Self;
38
39    /// Create from scalar value
40    fn from_scalar(s: Self::Scalar) -> Self;
41
42    /// Extract scalar value
43    fn to_scalar(self) -> Self::Scalar;
44
45    /// Check if this is the zero element
46    fn is_zero(&self) -> bool;
47}
48
49/// Extended semiring operations for automatic differentiation.
50///
51/// This trait adds argmax tracking needed for tropical backpropagation.
52pub trait Algebra: Semiring {
53    /// Index type for argmax tracking
54    type Index: Copy + Clone + Send + Sync + Default + std::fmt::Debug + 'static;
55
56    /// Addition with argmax tracking.
57    ///
58    /// Returns (result, winner_index) where winner_index indicates
59    /// which operand "won" the addition (relevant for tropical max/min).
60    fn add_with_argmax(
61        self,
62        self_idx: Self::Index,
63        rhs: Self,
64        rhs_idx: Self::Index,
65    ) -> (Self, Self::Index);
66
67    /// Backward pass for addition.
68    ///
69    /// Given output gradient `grad_out`, compute gradients for inputs.
70    /// For standard arithmetic: both inputs get `grad_out`.
71    /// For tropical: only the winner gets `grad_out`.
72    fn add_backward(
73        self,
74        rhs: Self,
75        grad_out: Self::Scalar,
76        winner_idx: Option<Self::Index>,
77    ) -> (Self::Scalar, Self::Scalar);
78
79    /// Backward pass for multiplication.
80    ///
81    /// Given output gradient `grad_out`, compute gradients for inputs.
82    /// Standard: grad_a = grad_out × b, grad_b = grad_out × a
83    /// Tropical (add): grad_a = grad_out, grad_b = grad_out
84    fn mul_backward(self, rhs: Self, grad_out: Self::Scalar) -> (Self::Scalar, Self::Scalar);
85
86    /// Whether this algebra requires argmax tracking for backprop.
87    fn needs_argmax() -> bool {
88        false
89    }
90
91    /// Check if `self` is "better" than `other` for tropical selection.
92    ///
93    /// For MaxPlus: returns true if self > other
94    /// For MinPlus: returns true if self < other
95    /// For Standard: not meaningful (always false)
96    fn is_better(&self, other: &Self) -> bool;
97}