// omeinsum/algebra/semiring.rs
//! Core algebraic traits for tensor operations.

use super::Scalar;

5/// A semiring defines two binary operations (⊕, ⊗) with identities.
6///
7/// # Semiring Laws
8///
9/// For a semiring (S, ⊕, ⊗, 0, 1):
10/// - (S, ⊕, 0) is a commutative monoid
11/// - (S, ⊗, 1) is a monoid
12/// - ⊗ distributes over ⊕
13/// - 0 annihilates: a ⊗ 0 = 0 ⊗ a = 0
14///
15/// # Examples
16///
17/// | Semiring | ⊕ | ⊗ | 0 | 1 |
18/// |----------|---|---|---|---|
19/// | Standard | + | × | 0 | 1 |
20/// | MaxPlus | max | + | -∞ | 0 |
21/// | MinPlus | min | + | +∞ | 0 |
22/// | MaxMul | max | × | 0 | 1 |
23pub trait Semiring: Copy + Clone + Send + Sync + 'static {
24 /// The underlying scalar type
25 type Scalar: Scalar;
26
27 /// Additive identity (zero element for ⊕)
28 fn zero() -> Self;
29
30 /// Multiplicative identity (one element for ⊗)
31 fn one() -> Self;
32
33 /// Addition operation (⊕)
34 fn add(self, rhs: Self) -> Self;
35
36 /// Multiplication operation (⊗)
37 fn mul(self, rhs: Self) -> Self;
38
39 /// Create from scalar value
40 fn from_scalar(s: Self::Scalar) -> Self;
41
42 /// Extract scalar value
43 fn to_scalar(self) -> Self::Scalar;
44
45 /// Check if this is the zero element
46 fn is_zero(&self) -> bool;
47}
48
49/// Extended semiring operations for automatic differentiation.
50///
51/// This trait adds argmax tracking needed for tropical backpropagation.
52pub trait Algebra: Semiring {
53 /// Index type for argmax tracking
54 type Index: Copy + Clone + Send + Sync + Default + std::fmt::Debug + 'static;
55
56 /// Addition with argmax tracking.
57 ///
58 /// Returns (result, winner_index) where winner_index indicates
59 /// which operand "won" the addition (relevant for tropical max/min).
60 fn add_with_argmax(
61 self,
62 self_idx: Self::Index,
63 rhs: Self,
64 rhs_idx: Self::Index,
65 ) -> (Self, Self::Index);
66
67 /// Backward pass for addition.
68 ///
69 /// Given output gradient `grad_out`, compute gradients for inputs.
70 /// For standard arithmetic: both inputs get `grad_out`.
71 /// For tropical: only the winner gets `grad_out`.
72 fn add_backward(
73 self,
74 rhs: Self,
75 grad_out: Self::Scalar,
76 winner_idx: Option<Self::Index>,
77 ) -> (Self::Scalar, Self::Scalar);
78
79 /// Backward pass for multiplication.
80 ///
81 /// Given output gradient `grad_out`, compute gradients for inputs.
82 /// Standard: grad_a = grad_out × b, grad_b = grad_out × a
83 /// Tropical (add): grad_a = grad_out, grad_b = grad_out
84 fn mul_backward(self, rhs: Self, grad_out: Self::Scalar) -> (Self::Scalar, Self::Scalar);
85
86 /// Whether this algebra requires argmax tracking for backprop.
87 fn needs_argmax() -> bool {
88 false
89 }
90
91 /// Check if `self` is "better" than `other` for tropical selection.
92 ///
93 /// For MaxPlus: returns true if self > other
94 /// For MinPlus: returns true if self < other
95 /// For Standard: not meaningful (always false)
96 fn is_better(&self, other: &Self) -> bool;
97}