omeinsum/algebra/
standard.rs1use super::semiring::{Algebra, Semiring};
4use super::Scalar;
5use num_traits::{One, Zero};
6
7#[derive(Debug, Clone, Copy, PartialEq)]
23#[repr(transparent)]
24pub struct Standard<T: Scalar>(pub T);
25
26impl<
27 T: Scalar + Zero + One + PartialEq + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
28 > Semiring for Standard<T>
29{
30 type Scalar = T;
31
32 #[inline]
33 fn zero() -> Self {
34 Standard(T::zero())
35 }
36
37 #[inline]
38 fn one() -> Self {
39 Standard(T::one())
40 }
41
42 #[inline]
43 fn add(self, rhs: Self) -> Self {
44 Standard(self.0 + rhs.0)
45 }
46
47 #[inline]
48 fn mul(self, rhs: Self) -> Self {
49 Standard(self.0 * rhs.0)
50 }
51
52 #[inline]
53 fn from_scalar(s: T) -> Self {
54 Standard(s)
55 }
56
57 #[inline]
58 fn to_scalar(self) -> T {
59 self.0
60 }
61
62 #[inline]
63 fn is_zero(&self) -> bool {
64 self.0 == T::zero()
65 }
66}
67
68impl<
69 T: Scalar + Zero + One + PartialEq + std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
70 > Algebra for Standard<T>
71{
72 type Index = u32;
73
74 #[inline]
75 fn add_with_argmax(
76 self,
77 _self_idx: Self::Index,
78 rhs: Self,
79 _rhs_idx: Self::Index,
80 ) -> (Self, Self::Index) {
81 (self.add(rhs), 0)
83 }
84
85 #[inline]
86 fn add_backward(
87 self,
88 _rhs: Self,
89 grad_out: Self::Scalar,
90 _winner_idx: Option<Self::Index>,
91 ) -> (Self::Scalar, Self::Scalar) {
92 (grad_out, grad_out)
94 }
95
96 #[inline]
97 fn mul_backward(self, rhs: Self, grad_out: Self::Scalar) -> (Self::Scalar, Self::Scalar) {
98 (
101 Standard(grad_out).mul(rhs).to_scalar(),
102 Standard(grad_out).mul(self).to_scalar(),
103 )
104 }
105
106 #[inline]
107 fn needs_argmax() -> bool {
108 false
109 }
110
111 #[inline]
112 fn is_better(&self, _other: &Self) -> bool {
113 false
115 }
116}
117
#[cfg(test)]
mod tests {
    //! Unit tests for the standard (sum, product) semiring and its `Algebra` hooks.

    use super::*;

    #[test]
    fn test_standard_f32() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        assert_eq!(a.add(b).to_scalar(), 5.0);
        assert_eq!(a.mul(b).to_scalar(), 6.0);
        assert_eq!(Standard::<f32>::zero().to_scalar(), 0.0);
        assert_eq!(Standard::<f32>::one().to_scalar(), 1.0);
    }

    #[test]
    fn test_standard_scalar_roundtrip_and_is_zero() {
        assert_eq!(Standard::<f32>::from_scalar(4.5).to_scalar(), 4.5);

        assert!(Standard::<f32>::zero().is_zero());
        // IEEE 754 defines -0.0 == 0.0, so negative zero is also reported as zero.
        assert!(Standard(-0.0f32).is_zero());
        assert!(!Standard(1.0f32).is_zero());
    }

    #[test]
    fn test_standard_backward() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        // d(a + b)/da = d(a + b)/db = 1.
        let (ga, gb) = a.add_backward(b, 1.0, None);
        assert_eq!(ga, 1.0);
        assert_eq!(gb, 1.0);

        // d(a * b)/da = b, d(a * b)/db = a.
        let (ga, gb) = a.mul_backward(b, 1.0);
        assert_eq!(ga, 3.0);
        assert_eq!(gb, 2.0);
    }

    #[test]
    fn test_standard_backward_scales_with_grad_out() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        assert_eq!(a.add_backward(b, 2.5, None), (2.5, 2.5));
        assert_eq!(a.mul_backward(b, 2.0), (6.0, 4.0));
    }

    #[test]
    fn test_standard_argmax_hooks_are_inert() {
        let a = Standard(2.0f32);
        let b = Standard(3.0f32);

        // Input indices are ignored; the placeholder index 0 is always returned.
        let (sum, idx) = a.add_with_argmax(7, b, 9);
        assert_eq!(sum.to_scalar(), 5.0);
        assert_eq!(idx, 0);

        assert!(!Standard::<f32>::needs_argmax());
        assert!(!a.is_better(&b));
        assert!(!b.is_better(&a));
    }
}