// omeinsum/algebra/standard.rs

//! Standard arithmetic semiring `(+, ×)`.

use num_traits::{One, Zero};

use super::semiring::{Algebra, Semiring};
use super::Scalar;

7/// Standard arithmetic semiring with addition and multiplication.
8///
9/// This represents the usual `(+, ×)` operations used in linear algebra.
10///
11/// # Example
12///
13/// ```rust
14/// use omeinsum::algebra::{Standard, Semiring};
15///
16/// let a = Standard(2.0f32);
17/// let b = Standard(3.0f32);
18///
19/// assert_eq!(a.add(b).to_scalar(), 5.0);  // 2 + 3 = 5
20/// assert_eq!(a.mul(b).to_scalar(), 6.0);  // 2 × 3 = 6
21/// ```
22#[derive(Debug, Clone, Copy, PartialEq)]
23#[repr(transparent)]
24pub struct Standard<T: Scalar>(pub T);
25
26impl<
27        T: Scalar + Zero + One + PartialEq + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
28    > Semiring for Standard<T>
29{
30    type Scalar = T;
31
32    #[inline]
33    fn zero() -> Self {
34        Standard(T::zero())
35    }
36
37    #[inline]
38    fn one() -> Self {
39        Standard(T::one())
40    }
41
42    #[inline]
43    fn add(self, rhs: Self) -> Self {
44        Standard(self.0 + rhs.0)
45    }
46
47    #[inline]
48    fn mul(self, rhs: Self) -> Self {
49        Standard(self.0 * rhs.0)
50    }
51
52    #[inline]
53    fn from_scalar(s: T) -> Self {
54        Standard(s)
55    }
56
57    #[inline]
58    fn to_scalar(self) -> T {
59        self.0
60    }
61
62    #[inline]
63    fn is_zero(&self) -> bool {
64        self.0 == T::zero()
65    }
66}
67
68impl<
69        T: Scalar + Zero + One + PartialEq + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
70    > Algebra for Standard<T>
71{
72    type Index = u32;
73
74    #[inline]
75    fn add_with_argmax(
76        self,
77        _self_idx: Self::Index,
78        rhs: Self,
79        _rhs_idx: Self::Index,
80    ) -> (Self, Self::Index) {
81        // Standard addition doesn't track argmax
82        (self.add(rhs), 0)
83    }
84
85    #[inline]
86    fn add_backward(
87        self,
88        _rhs: Self,
89        grad_out: Self::Scalar,
90        _winner_idx: Option<Self::Index>,
91    ) -> (Self::Scalar, Self::Scalar) {
92        // Standard addition: both inputs get the full gradient
93        (grad_out, grad_out)
94    }
95
96    #[inline]
97    fn mul_backward(self, rhs: Self, grad_out: Self::Scalar) -> (Self::Scalar, Self::Scalar) {
98        // Standard multiplication: chain rule
99        // d/da (a × b) = b, d/db (a × b) = a
100        (
101            Standard(grad_out).mul(rhs).to_scalar(),
102            Standard(grad_out).mul(self).to_scalar(),
103        )
104    }
105
106    #[inline]
107    fn needs_argmax() -> bool {
108        false
109    }
110
111    #[inline]
112    fn is_better(&self, _other: &Self) -> bool {
113        // Standard algebra accumulates all values, no "better" comparison
114        false
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_standard_f32() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        assert_eq!(a.add(b).to_scalar(), 5.0);
        assert_eq!(a.mul(b).to_scalar(), 6.0);
        assert_eq!(Standard::<f32>::zero().to_scalar(), 0.0);
        assert_eq!(Standard::<f32>::one().to_scalar(), 1.0);
    }

    /// `zero` and `one` must act as the additive / multiplicative identities,
    /// and `zero` must annihilate under multiplication.
    #[test]
    fn test_standard_identities() {
        let x = Standard(4.5f32);

        assert_eq!(x.add(Standard::zero()), x);
        assert_eq!(x.mul(Standard::one()), x);
        assert_eq!(x.mul(Standard::zero()).to_scalar(), 0.0);
    }

    /// `from_scalar` / `to_scalar` are a lossless round trip.
    #[test]
    fn test_standard_scalar_round_trip() {
        assert_eq!(Standard::<f32>::from_scalar(7.25).to_scalar(), 7.25);
    }

    /// `is_zero` follows `==` against `T::zero()`: both signed zeros count,
    /// non-zero values and NaN do not.
    #[test]
    fn test_standard_is_zero() {
        assert!(Standard(0.0f32).is_zero());
        assert!(Standard(-0.0f32).is_zero());
        assert!(!Standard(1.0f32).is_zero());
        assert!(!Standard(f32::NAN).is_zero());
    }

    /// Standard addition ignores argmax bookkeeping entirely.
    #[test]
    fn test_standard_argmax_is_not_tracked() {
        let (sum, idx) = Standard(2.0f32).add_with_argmax(5, Standard(3.0f32), 9);
        assert_eq!(sum.to_scalar(), 5.0);
        assert_eq!(idx, 0);

        assert!(!Standard::<f32>::needs_argmax());
        assert!(!Standard(2.0f32).is_better(&Standard(3.0f32)));
        assert!(!Standard(3.0f32).is_better(&Standard(2.0f32)));
    }

    #[test]
    fn test_standard_backward() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        // Add backward
        let (ga, gb) = a.add_backward(b, 1.0, None);
        assert_eq!(ga, 1.0);
        assert_eq!(gb, 1.0);

        // Mul backward: d/da(a*b) = b, d/db(a*b) = a
        let (ga, gb) = a.mul_backward(b, 1.0);
        assert_eq!(ga, 3.0); // grad_out * b
        assert_eq!(gb, 2.0); // grad_out * a
    }

    /// With a non-unit upstream gradient the results must actually be scaled
    /// by `grad_out`; a unit gradient cannot distinguish `grad_out * b` from `b`.
    #[test]
    fn test_standard_backward_scales_by_grad_out() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        let (ga, gb) = a.add_backward(b, 0.5, Some(1));
        assert_eq!((ga, gb), (0.5, 0.5));

        let (ga, gb) = a.mul_backward(b, 0.5);
        assert_eq!(ga, 1.5); // 0.5 * 3
        assert_eq!(gb, 1.0); // 0.5 * 2
    }
}